"transformer_engine/jax/csrc/extensions/pybind.cpp" did not exist on "c473f0e67e1ad748f47263e5a63f6e30f832dc15"
test_common.cu 40.3 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/


#include "test_common.h"
Tim Moon's avatar
Tim Moon committed
9

Przemek Tredak's avatar
Przemek Tredak committed
10
11
12
#include <algorithm>
#include <memory>
#include <random>
13
#include <iostream>
14
15
16
#include <cassert>
#include <cmath>
#include <string>
Przemek Tredak's avatar
Przemek Tredak committed
17

Tim Moon's avatar
Tim Moon committed
18
#include <gtest/gtest.h>
19
#include <omp.h>
Tim Moon's avatar
Tim Moon committed
20
21
22
23

#include <transformer_engine/transformer_engine.h>
#include "util/logging.h"

Przemek Tredak's avatar
Przemek Tredak committed
24
25
namespace test {

26
27
28
29
30
31
size_t create_seed_from_tensor_name(const std::string& tensor_name) {
  auto full_name = std::string(testing::UnitTest::GetInstance()->current_test_info()->name()) +
                   "/" + tensor_name;
  return std::hash<std::string>{}(full_name);
}

Przemek Tredak's avatar
Przemek Tredak committed
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
// Floating-point data types: FP32, FP16, BF16 and the E5M2 / E4M3 FP8 formats.
std::vector<DType> all_fp_types{
    DType::kFloat32,
    DType::kFloat16,
    DType::kBFloat16,
    DType::kFloat8E5M2,
    DType::kFloat8E4M3,
};

// Returns true when both shapes have the same number of dimensions and every
// dimension has the same extent.
bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
  if (s1.ndim != s2.ndim) {
    return false;
  }
  // Element-wise comparison of the first ndim extents.
  return std::equal(s1.data, s1.data + s1.ndim, s2.data);
}

48
// Returns TypeInfo<T>::size for the C++ type T that the type-switch macro
// associates with `type`. Callers in this file divide the result by 8 to get
// a byte count, i.e. the value is treated as a width in bits.
size_t typeToNumBits(DType type) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
  {
      return TypeInfo<T>::size;
  });
}

// Returns a human-readable name for `type`.
// The returned reference points into a function-local static table; a type
// missing from the table makes std::unordered_map::at throw std::out_of_range.
const std::string &typeName(DType type) {
  static const std::unordered_map<DType, std::string> name_map = {
    {DType::kByte, "byte"},
    {DType::kInt32, "int32"},
    {DType::kInt64, "int64"},
    {DType::kFloat32, "float32"},
    {DType::kFloat16, "float16"},
    {DType::kBFloat16, "bfloat16"},
    {DType::kFloat8E4M3, "float8e4m3"},
    {DType::kFloat8E5M2, "float8e5m2"},
    {DType::kFloat8E8M0, "float8e8m0"},
    {DType::kFloat4E2M1, "float4e2m1"}};
  return name_map.at(type);
}

70
71
72
73
74
75
76
77
78
79
80
// Returns the printable label of an input-fill case.
// The lookup uses std::unordered_map::at, so an unlisted case throws std::out_of_range.
const std::string& caseName(InputsFillCase type) {
  using LabelTable = std::unordered_map<InputsFillCase, std::string>;
  static const LabelTable labels{
      {InputsFillCase::uniform, "uniform"},
      {InputsFillCase::zeros, "zeros"},
      {InputsFillCase::zero_to_minNorm, "zero_to_minNorm"},
      {InputsFillCase::minNorm_to_maxNorm, "minNorm_to_maxNorm"},
      {InputsFillCase::maxNorm_to_inf, "maxNorm_to_inf"},
  };
  return labels.at(type);
}

size_t product(const NVTEShape &shape, size_t begin, size_t end) {
Przemek Tredak's avatar
Przemek Tredak committed
81
    size_t ret = 1;
82
83
    NVTE_CHECK(end <= shape.ndim);
    for (size_t i = begin; i < end; ++i) {
Przemek Tredak's avatar
Przemek Tredak committed
84
85
86
87
      ret *= shape.data[i];
    }
    return ret;
}
88

89
90
91
size_t product(const NVTEShape &shape) {
  return product(shape, 0, shape.ndim);
}
92

93
94
95
96
97
98
99
100
// Multiplies the entries of `shape` over the half-open index range [begin, end).
// `end` must not exceed shape.size() (checked with NVTE_CHECK); an empty range yields 1.
size_t product(const std::vector<size_t> shape, size_t begin, size_t end) {
    NVTE_CHECK(end <= shape.size());
    size_t result = 1;
    size_t idx = begin;
    while (idx < end) {
      result *= shape[idx];
      ++idx;
    }
    return result;
}
Przemek Tredak's avatar
Przemek Tredak committed
101

102
103
104
105
106
107
108
109
// Product of all entries of `shape`.
size_t product(const std::vector<size_t>& shape) {
  const size_t num_dims = shape.size();
  return product(shape, 0, num_dims);
}

// Integer division of x by y, rounded up. y must be non-zero.
size_t DIVUP(const size_t &x, const size_t &y){
  const size_t bumped = x + (y - 1);
  return bumped / y;
}

110
111
112
113
// Rounds x up to the nearest multiple of y.
size_t DIVUP_TO_MULTIPLE(const size_t &x, const size_t &y){
  const size_t num_chunks = DIVUP(x, y);
  return num_chunks * y;
}

114
115
116
// Describes a scale-inverse buffer: its shape, element dtype and element width in bits.
struct scale_inv_meta {
  std::vector<size_t> shape;
  DType type;
  size_t type_size_bits;
  // Size of the buffer in bytes: (number of elements * bits per element) / 8.
  size_t bytes() const noexcept {
    return (product(shape) * type_size_bits) / 8;
  }
};

123
124
125
126
// Number of bytes needed to store a tensor of the given shape and dtype:
// (element count * bits per element) / 8.
size_t bytes(const NVTEShape& shape, const DType type) {
  const size_t total_bits = product(shape) * typeToNumBits(type);
  return total_bits / 8;
}

127
128
// Builds an NVTEShape from a vector of dimension extents via nvte_make_shape.
NVTEShape convertShape(const std::vector<size_t>& s) {
  return nvte_make_shape(s.data(), s.size());
}

std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
                                                     const NVTEScalingMode scaling_mode) {
  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
    scale_inv_meta ret;
    ret.shape = {1};
    ret.type = DType::kFloat32;
137
    ret.type_size_bits = typeToNumBits(DType::kFloat32);
138
139
140
141
142
143
144
145
146
147
148
149
    return {ret, ret};
  }
  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    scale_inv_meta ret_rowwise, ret_colwise;

150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
    const size_t block_size_X_rowwise = 32;
    size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
    ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};

    const size_t block_size_Y_colwise = 32;
    size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
    size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
    ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};

    ret_rowwise.type = DType::kFloat8E8M0;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
    ret_colwise.type = DType::kFloat8E8M0;
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);

    return {ret_rowwise, ret_colwise};
  }
  if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
171
    }
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    NVTE_CHECK(last_dim % 32 == 0);
    NVTE_CHECK(first_dim % 32 == 0);

    scale_inv_meta ret_rowwise, ret_colwise;

    size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise);
    ret_rowwise.shape = {scale_dim_Y, scale_dim_X};

    size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise);
    ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t};

    ret_rowwise.type = DType::kFloat8E4M3;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
    ret_colwise.type = DType::kFloat8E4M3;
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);

    return {ret_rowwise, ret_colwise};
  }
  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
Przemek Tredak's avatar
Przemek Tredak committed
199
    }
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    scale_inv_meta ret_rowwise, ret_colwise;

    const size_t block_size_X_rowwise = 32;
    size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
    ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};

    const size_t block_size_Y_colwise = 32;
    size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
    size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
    ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};

215
216
    ret_rowwise.type = DType::kFloat8E8M0;
    ret_colwise.type = DType::kFloat8E8M0;
217
218
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
219
220
221

    return {ret_rowwise, ret_colwise};
  }
222
223
224
225
226
227
228
229
230
231
232
  if (scaling_mode == NVTE_BLOCK_SCALING_2D) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    scale_inv_meta ret_rowwise, ret_colwise;

    {
233
234
      size_t scale_dim_0 = DIVUP(first_dim, 128lu);
      size_t scale_dim_1 = DIVUP(DIVUP(last_dim, 128lu), 4) * 4;
235
236
237
      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
    }
    {
238
239
      size_t scale_dim_0 = DIVUP(last_dim, 128lu);
      size_t scale_dim_1 = DIVUP(DIVUP(first_dim, 128lu), 4) * 4;
240
241
242
243
      ret_colwise.shape = {scale_dim_0, scale_dim_1};
    }
    ret_rowwise.type = DType::kFloat32;
    ret_colwise.type = DType::kFloat32;
244
245
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
246
247
248
249
250
251
252
253
254
255
256
257
258

    return {ret_rowwise, ret_colwise};
  }
  if (scaling_mode == NVTE_BLOCK_SCALING_1D) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);
    scale_inv_meta ret_rowwise, ret_colwise;

    {
259
260
      size_t scale_dim_0 = DIVUP(last_dim, 128lu);
      size_t scale_dim_1 = DIVUP(first_dim, 4) * 4;
261
262
263
      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
    }
    {
264
265
      size_t scale_dim_0 = DIVUP(first_dim, 128lu);
      size_t scale_dim_1 = DIVUP(last_dim, 4) * 4;
266
267
268
269
      ret_colwise.shape = {scale_dim_0, scale_dim_1};
    }
    ret_rowwise.type = DType::kFloat32;
    ret_colwise.type = DType::kFloat32;
270
271
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
272
273
    return {ret_rowwise, ret_colwise};
  }
274
275
276
277
278
279
280

  NVTE_ERROR("Invalid scaling mode!");
}

// Constructs a test tensor.
//
// name        Tensor name; also used (with the current test name) to seed the RNG.
// shape       Logical tensor shape.
// type        Element dtype of the data.
// rowwise     Allocate and register the row-wise data (and scale-inverse) buffers.
// columnwise  Allocate and register the column-wise buffers; requires ndim >= 2.
// scaling_mode Scaling mode handed to the wrapped TensorWrapper.
//
// Every device buffer is allocated with cudaMalloc and zero-filled, and is paired
// with a zero-filled host buffer. Return codes of the CUDA calls are not checked.
Tensor::Tensor(const std::string& name,
               const NVTEShape &shape, const DType type,
               const bool rowwise, const bool columnwise,
               const NVTEScalingMode &scaling_mode)
  : tensor_(scaling_mode), rowwise_{rowwise}, columnwise_{columnwise}, name_{name} {
  // Initialize RNG
  const size_t seed = create_seed_from_tensor_name(name);
  gen_.seed(seed);

  // Make sure shape is valid
  if (columnwise) {
    NVTE_CHECK(shape.ndim >= 2);
  }

  // Shape after flattening to 2D
  NVTEShape flattened_shape;
  {
    std::vector<size_t> flattened_shape_vec;
    if (shape.ndim > 0) {
      flattened_shape_vec.push_back(product(shape, 0, shape.ndim - 1));
      flattened_shape_vec.push_back(shape.data[shape.ndim - 1]);
    } else {
      // A 0-dimensional tensor is treated as 1x1
      flattened_shape_vec.resize(2, 1);
    }
    flattened_shape = convertShape(flattened_shape_vec);
  }

  // Allocate and initialize data
  void *dptr_rowwise = nullptr, *dptr_columnwise = nullptr;
  const size_t total_size = bytes(shape, type);
  if (total_size != 0) {
    if (rowwise) {
      cudaMalloc((void**)&dptr_rowwise, total_size);  // NOLINT(*)
      cudaMemset(dptr_rowwise, 0, total_size);
      cpu_data_rowwise_ = std::make_unique<unsigned char[]>(total_size);
      std::fill_n(cpu_data_rowwise_.get(), total_size, 0);
    }
    if (columnwise) {
      cudaMalloc((void**)&dptr_columnwise, total_size);  // NOLINT(*)
      cudaMemset(dptr_columnwise, 0, total_size);
      cpu_data_columnwise_ = std::make_unique<unsigned char[]>(total_size);
      std::fill_n(cpu_data_columnwise_.get(), total_size, 0);
    }
  }

  // With NVFP4 scaling the data buffers are registered as FP4 (E2M1)
  // regardless of the requested type.
  const DType data_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;

  // Set tensor row-wise data
  if (rowwise) {
    tensor_.set_rowwise_data(dptr_rowwise, data_type, shape);
  }

  // Set tensor column-wise data
  if (columnwise) {
    // Determine shape of column-wise data
    std::vector<size_t> columnwise_shape_vec;
    switch (scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING:
    case NVTE_BLOCK_SCALING_1D:
    case NVTE_BLOCK_SCALING_2D: {
      // Column-wise data shape is transposed
      if (shape.ndim > 0) {
        columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
        for (size_t i = 0; i < shape.ndim - 1; ++i) {
          columnwise_shape_vec.emplace_back(shape.data[i]);
        }
      }
      break;
    }
    case NVTE_MXFP8_1D_SCALING:
    case NVTE_NVFP4_1D_SCALING: {
      // Column-wise data matches shape
      for (size_t i = 0; i < shape.ndim; ++i) {
        columnwise_shape_vec.emplace_back(shape.data[i]);
      }
      break;
    }
    default:
      NVTE_ERROR("Unrecognized scaling mode (", (size_t)scaling_mode, ").");
    }
    const auto columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(),
                                                  columnwise_shape_vec.size());

    // Set column-wise data buffer
    tensor_.set_columnwise_data(dptr_columnwise, data_type, columnwise_shape);
  }

  // Configure scales, amaxes, and other tensor buffers
  float *amax = nullptr, *scale = nullptr;
  float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr;
  if (isFp8Type(type) || isFp4Type(type)) {
    if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
      cudaMalloc((void**)&amax, sizeof(float));  // NOLINT(*)
      cudaMemset(amax, 0, sizeof(float));
      cudaMalloc((void**)&scale, sizeof(float));  // NOLINT(*)
      cudaMemset(scale, 0, sizeof(float));
      amax_cpu_data_ = std::make_shared<float>(0);
      scale_cpu_data_ = std::make_shared<float>(0);
      tensor_.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
      tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      // One device scale-inverse value is shared by the row-wise and column-wise views.
      cudaMalloc((void**)&rowwise_scale_inv, sizeof(float));  // NOLINT(*)
      // Zero it like every other buffer so it matches the zero-filled host copies.
      cudaMemset(rowwise_scale_inv, 0, sizeof(float));
      if (rowwise) {
        tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
                                      std::vector<size_t>{1});
        rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
        std::fill_n(rowwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
      }
      if (columnwise) {
        tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
                                          std::vector<size_t>{1});
        columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
        std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
      }
    } else {
      if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
        // Used for NVFP4 second stage scaling
        cudaMalloc((void**)&scale, sizeof(float));  // NOLINT(*)
        cudaMemset(scale, 0, sizeof(float));
        scale_cpu_data_ = std::make_shared<float>(0);
        tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      }
      // Scale-inverse buffer layout is derived from the flattened 2D shape
      auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(flattened_shape, tensor_.scaling_mode());
      auto rowwise_scale_size = rowwise_scale_meta.bytes();
      auto columnwise_scale_size = colwise_scale_meta.bytes();
      auto scale_shape = rowwise_scale_meta.shape;
      auto columnwise_scale_shape = colwise_scale_meta.shape;
      if (rowwise) {
        cudaMalloc((void **)&rowwise_scale_inv, rowwise_scale_size);  // NOLINT(*)
        cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size);
        rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(rowwise_scale_size);
        std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0);
        auto scale_dtype = rowwise_scale_meta.type;
        tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape);
      }
      if (columnwise) {
        cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size);  // NOLINT(*)
        cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size);
        columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(columnwise_scale_size);
        std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0);
        auto scale_dtype = colwise_scale_meta.type;
        tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape);
      }
    }
  }
}

void Tensor::to_cpu() const {
  const NVTEShape s = tensor_.shape();
426
  const size_t size = bytes(s, tensor_.dtype());
427
428
429
430
431
432
433
  if (rowwise_) {
    cudaMemcpy(cpu_data_rowwise_.get(),
               tensor_.get_rowwise_data().data_ptr,
               size,
               cudaMemcpyDeviceToHost);
  }
  if (columnwise_) {
434
435
436
    const DType colwise_type = tensor_.dtype();

    const size_t colwise_size = bytes(s, colwise_type);
437
    cudaMemcpy(cpu_data_columnwise_.get(),
438
439
440
                tensor_.get_columnwise_data().data_ptr,
                colwise_size,
                cudaMemcpyDeviceToHost);
441
  }
442
443
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
    if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) {
444
445
446
447
448
449
      if (tensor_.amax() != nullptr){
        cudaMemcpy(amax_cpu_data_.get(),
                  tensor_.amax(),
                  sizeof(float),
                  cudaMemcpyDeviceToHost);
      }
450
451
452
453
454
      cudaMemcpy(scale_cpu_data_.get(),
                 tensor_.scale(),
                 sizeof(float),
                 cudaMemcpyDeviceToHost);
    }
455
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
456
    if (rowwise_) {
457
      auto scale_size = rowwise_scale_meta.bytes();
458
459
460
461
462
463
      cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
                 tensor_.get_rowwise_scale_inv().data_ptr,
                 scale_size,
                 cudaMemcpyDeviceToHost);
    }
    if (columnwise_) {
464
      auto scale_size = colwise_scale_meta.bytes();
465
466
467
468
469
      cudaMemcpy(columnwise_scale_inv_cpu_data_.get(),
                 tensor_.get_columnwise_scale_inv().data_ptr,
                 scale_size,
                 cudaMemcpyDeviceToHost);
    }
470
  }
Przemek Tredak's avatar
Przemek Tredak committed
471
472
473
474
}

void Tensor::from_cpu() const {
  const NVTEShape s = tensor_.shape();
475
  const size_t size = bytes(s, tensor_.dtype());
476
  if (rowwise_) {
477
478
    cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size,
               cudaMemcpyHostToDevice);
479
480
  }
  if (columnwise_) {
481
482
    cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
               cudaMemcpyHostToDevice);
483
  }
484
485
486
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
    if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
        || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) {
487
      if (tensor_.amax() != nullptr){
488
        cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
489
      }
490
      cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
491
    }
492
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
493
    if (rowwise_) {
494
      auto scale_size = rowwise_scale_meta.bytes();
495
496
497
498
499
      cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
                 rowwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
    }
    if (columnwise_) {
500
      auto scale_size = colwise_scale_meta.bytes();
501
502
503
504
      cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr,
                 columnwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
    }
505
506
507
508
  }
}

void Tensor::set_scale(float scale) {
509
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
510
    NVTE_CHECK(scale_cpu_data_);
511
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
512
513
514
      *scale_cpu_data_ = scale;
      from_cpu();
    }
515
516
517
518
  }
}

void Tensor::set_scale_inv(float scale_inv) {
519
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
520
521
522
523
524
525
    if (rowwise_) {
      NVTE_CHECK(rowwise_scale_inv_cpu_data_);
    }
    if (columnwise_) {
      NVTE_CHECK(columnwise_scale_inv_cpu_data_);
    }
526

527
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode());
528
529
    if (rowwise_) {
      auto num_scales = product(rowwise_scale_meta.shape);
530
      if (num_scales == 1) {
531
        rowwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
532
      } else {
533
        std::uniform_int_distribution<uint8_t> dis(0, 127);
534
535
        auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr<uint8_t>();
        for (size_t i = 0; i < num_scales; i++) {
536
537
538
539
540
541
          scale_inv_ptr[i] = dis(gen_);
        }
      }
    }
    if (columnwise_) {
      auto num_scales = product(colwise_scale_meta.shape);
542
      if (num_scales == 1) {
543
        columnwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
544
      } else {
545
        std::uniform_int_distribution<uint8_t> dis(0, 127);
546
547
        auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr<uint8_t>();
        for (size_t i = 0; i < num_scales; i++) {
548
549
550
551
          scale_inv_ptr[i] = dis(gen_);
        }
      }
    }
552
553
554
555
556
    from_cpu();
  }
}

void Tensor::shareFP8Meta(const Tensor &other) {
557
558
  if ((isFp8Type(dtype()) && isFp8Type(other.dtype()))
      || isFp4Type(dtype()) && isFp4Type(other.dtype())) {
559
560
    auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
    auto my_rowwise_data = tensor_.get_rowwise_data();
561
    new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
562
563
564
565
566
567
                                my_rowwise_data.shape);
    auto my_columnwise_data = tensor_.get_columnwise_data();
    new_tensor.set_columnwise_data(my_columnwise_data.data_ptr,
                                   static_cast<DType>(my_columnwise_data.dtype),
                                   my_columnwise_data.shape);
    auto other_amax = other.tensor_.get_amax();
568
    new_tensor.set_amax(other_amax.data_ptr, static_cast<DType>(other_amax.dtype),
569
570
                        other_amax.shape);
    auto other_scale = other.tensor_.get_scale();
571
    new_tensor.set_scale(other_scale.data_ptr, static_cast<DType>(other_scale.dtype),
572
573
574
575
576
577
578
579
580
581
                         other_scale.shape);
    auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv();
    new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr,
                                     static_cast<DType>(other_row_scale_inv.dtype),
                                     other_row_scale_inv.shape);
    auto other_col_scale_inv = other.tensor_.get_columnwise_scale_inv();
    new_tensor.set_columnwise_scale_inv(other_col_scale_inv.data_ptr,
                                        static_cast<DType>(other_col_scale_inv.dtype),
                                        other_col_scale_inv.shape);
    tensor_ = std::move(new_tensor);
582
583
    to_cpu();
  }
Przemek Tredak's avatar
Przemek Tredak committed
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
}

using std::to_string;

// Formats a vector as "[a, b, c]", converting each element with to_string.
// An empty vector yields "[]".
template <typename T>
std::string to_string(const std::vector<T> &v) {
  std::string s = "[";
  bool first = true;
  for (const auto &x : v) {
    // Emit the separator before every element but the first, so nothing has to
    // be trimmed afterwards (trimming underflowed the string for an empty vector).
    if (!first) {
      s += ", ";
    }
    s += to_string(x);
    first = false;
  }
  return s + "]";
}

std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
  std::vector<size_t> ret;
  size_t current_i = i;
602
  for (size_t current = shape.ndim - 1; current > 0; --current) {
Przemek Tredak's avatar
Przemek Tredak committed
603
604
605
606
607
608
609
610
    ret.push_back(current_i % shape.data[current]);
    current_i /= shape.data[current];
  }
  ret.push_back(current_i);
  std::reverse(ret.begin(), ret.end());
  return ret;
}

611
612
void compareResults_sequential(const std::string &name, const Tensor &test,
                               const void *ref, const bool rowwise,
613
614
                               double atol, double rtol, bool if_on_gpus,
                               const size_t tolerable_mismatches_limit) {
615
616
617
  if (if_on_gpus) test.to_cpu();
  const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
  const size_t N = product(shape);
618
619
  size_t mismatches_num = 0;
  int first_mismatch_idx = -1;
Przemek Tredak's avatar
Przemek Tredak committed
620
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
621
    const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
Przemek Tredak's avatar
Przemek Tredak committed
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
    const T *ref_data = reinterpret_cast<const T*>(ref);
    for (size_t i = 0; i < N; ++i) {
      double t = static_cast<double>(test_data[i]);
      double r = static_cast<double>(ref_data[i]);
      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
      /* For Float32 the floating point comparison is enough to error out */
      bool assertion = mismatch && test.dtype() == DType::kFloat32;
      if (mismatch && !assertion) {
        /* Check if it is just a failure of round to nearest choosing different
           side of the real value */
        const double mean = (t + r) / 2;
        const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
        const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
        const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
        const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
        assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
      }
639
      std::string direction = rowwise ? "rowwise" : "columnwise";
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
      if (assertion) {
        mismatches_num++;
        if (first_mismatch_idx == -1) {
          first_mismatch_idx = i;
        }
      }
      if (mismatches_num > tolerable_mismatches_limit) {
        const double first_mismatch_t = static_cast<double>(test_data[first_mismatch_idx]);
        const double first_mismatch_r = static_cast<double>(ref_data[first_mismatch_idx]);

        GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of "
                    << tolerable_mismatches_limit << "." << std::endl
                    << "Error in tensor " << name << " in "
                    << direction << " direction." << std::endl
                     << "First mismatch at place " << to_string(unravel(first_mismatch_idx, shape))
                     << " (" << std::to_string(first_mismatch_idx) << "): "
                     << first_mismatch_t << " vs " << first_mismatch_r;
      }
658
659
660
661
662
663
    }
  );
}

template <typename T>
static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data,
664
665
                                  const size_t N, const double atol, const double rtol,
                                  size_t& mismatches) {
666
667
  int first_mismatch_idx = N;

668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
  #pragma omp parallel reduction(min: first_mismatch_idx) reduction(+: mismatches) proc_bind(spread)
  {
    size_t thread_mismatches = 0;
    #pragma omp for schedule(static)
    for (size_t i = 0; i < N; ++i) {
      double t = static_cast<double>(test_data[i]);
      double r = static_cast<double>(ref_data[i]);

      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
      /* For Float32 the floating point comparison is enough to error out */
      bool assertion = mismatch && (data_type == DType::kFloat32);
      if (mismatch && !assertion) {
        /* Check if it is just a failure of round to nearest choosing different
            side of the real value */
        const double mean = (t + r) / 2;
        const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
        const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
        const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
        const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
        assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
      }
      if (assertion) {
        if (i < first_mismatch_idx) {
          first_mismatch_idx = i;
        }
        thread_mismatches++;
      }
695
    }
696
    mismatches += thread_mismatches;
697
698
699
700
701
  }
  return first_mismatch_idx;
}

void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref,
702
703
                             const bool rowwise, double atol, double rtol, bool if_on_gpus,
                             const size_t tolerable_mismatches_limit) {
704
705
706
  if (if_on_gpus) test.to_cpu();
  const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
  const size_t N = product(shape);
707
  size_t mismatches = 0;
708
709
710
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
    const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
    const T *ref_data = reinterpret_cast<const T*>(ref);
Przemek Tredak's avatar
Przemek Tredak committed
711

712
713
    const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol, mismatches);
    if ((i != N) && (mismatches > tolerable_mismatches_limit)) {
714
715
716
      const double t = static_cast<double>(test_data[i]);
      const double r = static_cast<double>(ref_data[i]);
      std::string direction = rowwise ? "rowwise" : "columnwise";
717
718
719
720
721
722
723

      GTEST_FAIL() << mismatches << " mismatche(s) which is more than tolerable mismatch limit of "
                   << tolerable_mismatches_limit << "." << std::endl
                   << "Error in tensor " << name << " in "
                   << direction << " direction." << std::endl
                   << "Mismatch at place " << to_string(unravel(i, shape))
                   << " (" << std::to_string(i) << "): " << t << " vs " << r;
Przemek Tredak's avatar
Przemek Tredak committed
724
725
726
727
    }
  );
}

728
void compareResults(const std::string &name, const Tensor &test, const void *ref,
729
730
                    const bool rowwise, double atol, double rtol, bool if_on_gpus,
                    const size_t tolerable_mismatches_limit) {
731
732
  constexpr bool sequential = false;
  if constexpr (sequential) {
733
    compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit);
734
  } else {
735
    compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit);
736
737
738
  }
}

739
740
741
742
743
744
745
746
747
748
// Compares two scalar values. The pair counts as a mismatch only when the
// absolute difference exceeds atol AND the relative difference exceeds rtol
// (a zero reference always counts as exceeding rtol).
void compareResults(const std::string &name, const float test, const float ref,
                    double atol, double rtol) {
  const double actual = static_cast<double>(test);
  const double expected = static_cast<double>(ref);
  const double diff = actual - expected;

  const bool beyond_atol = fabs(diff) > atol;
  const bool beyond_rtol = (expected == 0) || (fabs(diff / expected) > rtol);
  const bool mismatch = beyond_atol && beyond_rtol;

  ASSERT_FALSE(mismatch) << "Error in " << name << std::endl
                         << "Mismatch: " << actual << " vs " << expected;
}

749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770

// Compares two byte buffers element by element and fails the current test as
// soon as the number of differing bytes exceeds ceil(N * mismatch_rate_tol).
//   name               - label printed in the error report
//   test, ref          - buffers of N bytes each
//   N                  - number of elements to compare
//   mismatch_rate_tol  - tolerated fraction of mismatching elements
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
                    size_t N, float mismatch_rate_tol) {
  const size_t max_mismatches = static_cast<size_t>(std::ceil(N * mismatch_rate_tol));
  size_t n_mismatches = 0;
  std::vector<size_t> mismatch_indices;
  for (size_t i = 0; i < N; ++i) {
    if (test[i] == ref[i]) {
      continue;
    }
    ++n_mismatches;
    mismatch_indices.push_back(i);

    // The count only changes on a mismatch, so the limit is checked here.
    if (n_mismatches > max_mismatches) {
      std::cout << "Error in " << name << std::endl;
      for (const size_t index : mismatch_indices) {
        // Report the values at each recorded position, not at the current i.
        std::cout << "Mismatch at (" << index << "):" << static_cast<int>(test[index]) << " vs "
                  << static_cast<int>(ref[index]) << std::endl;
      }
      GTEST_FAIL() << n_mismatches << " mismatche(s) which is more than mismatch tol.";
    }
  }
}

771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
// Maps a scaling-factor element type to the type it is converted to before
// being streamed in mismatch reports. The primary template is intentionally
// left undefined: only the element types specialized below can be used.
template <typename T>
struct CastToType;

// uint8_t scaling factors are printed as integers.
template <> struct CastToType<uint8_t> { using type = int; };

// fp8e4m3 scaling factors are printed as floats.
template <> struct CastToType<fp8e4m3> { using type = float; };

template <typename T>
void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
                             const size_t row_blocks, const size_t col_blocks, const size_t stride,
                             size_t& mismatches_num, const size_t atol,
                             const double abs_tolerable_mismatches_limit,
                             const double rel_tolerable_mismatches_limit)
790
{
791
792
793
794
  using UpcastType = typename CastToType<T>::type;
  auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3);


795
796
797
798
799
800
  const size_t N = row_blocks * col_blocks;
  const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit,
                                                     std::floor(N * rel_tolerable_mismatches_limit));
  mismatches_num = 0;
  std::vector<int> mismatch_indices;

801
802
803
  for (int i = 0; i < row_blocks; ++i) {
    for (int j = 0; j < col_blocks; ++j) {
      const int idx = i * stride + j;
804
805
806
      float t, r;

      bool assertion = false;
807

808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
      if (std::is_same<T, uint8_t>::value) {
        t = static_cast<float>(test[idx]);
        r = static_cast<float>(ref[idx]);
        assertion = std::abs(t - r) > atol;
      } else {
        t = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&test[idx]));
        r = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&ref[idx]));
        const bool mismatch = (fabs(t - r) > atol_fp8e4m3)
                              && (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3);
        if (mismatch) {
          /* Check if it is just a failure of round to nearest choosing different
            side of the real value */
          const double mean = (t + r) / 2;
          const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
          const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
          const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
          const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
          assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
        }
      }
      if (assertion) {
829
830
831
832
833
834
835
        mismatches_num++;
        mismatch_indices.push_back(idx);
      }
      if (mismatches_num > tolerable_mismatches_limit) {
        std::cout << "Error in " << name << std::endl;
        for (const int index : mismatch_indices) {
          std::cout << "Mismatch at (" << index << "):"
836
837
                    << static_cast<UpcastType>(test[index]) << " vs "
                    << static_cast<UpcastType>(ref[index]) << std::endl;
838
839
840
841
842
        }
        GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of "
                     << tolerable_mismatches_limit << ".";
      }
    }
843
844
845
  }
}

846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
// Explicit instantiations of compare_scaling_factors for the two element
// types that have a CastToType specialization (uint8_t and fp8e4m3). The
// template is defined in this source file, so these instantiations are what
// make the two variants available to other translation units.
template
void compare_scaling_factors<uint8_t>(const std::string &name, const uint8_t *test, const uint8_t *ref,
                                      const size_t row_blocks, const size_t col_blocks, const size_t stride,
                                      size_t& mismatches_num, const size_t atol,
                                      const double abs_tolerable_mismatches_limit,
                                      const double rel_tolerable_mismatches_limit);

template
void compare_scaling_factors<fp8e4m3>(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref,
                                      const size_t row_blocks, const size_t col_blocks, const size_t stride,
                                      size_t& mismatches_num, const size_t atol,
                                      const double abs_tolerable_mismatches_limit,
                                      const double rel_tolerable_mismatches_limit);


Przemek Tredak's avatar
Przemek Tredak committed
862
863
864
865
866
867
868
869
870
871
// Returns the default comparison tolerances for a data type as a pair whose
// first element is the absolute tolerance and whose second element is the
// relative tolerance.
// Types without an entry are rejected through a failing NVTE_CHECK. The
// previous form passed the message string as the condition, which is always
// true, so the check could never fire.
std::pair<double, double> getTolerances(const DType type) {
  switch(type) {
    case DType::kFloat32:
      return {1e-6, 5e-6};
    case DType::kFloat16:
      return {1e-5, 1e-3};
    case DType::kBFloat16:
      return {1e-5, 1e-2};
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
    case DType::kFloat8E8M0:
      return {1e-2, 1e-2};
    default:
      NVTE_CHECK(false, "Invalid type!");
  }
  // Not expected to be reached; keeps every control path returning a value.
  return {0, 0};
}

880
881
template <typename T>
void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
882
883
884
885
886
887
888
889
890
891
892
893
894
  // Check how many RNG calls are required to generate one uniform random value
  int rng_calls_per_val = 0;
  {
    std::mt19937 gen1 = *gen, gen2 = *gen;
    std::uniform_real_distribution<> dis(-2.0, 1.0);
    const float _ = dis(gen1);
    while (gen2 != gen1) {
      auto _ = gen2();
      ++rng_calls_per_val;
    }
  }

  // Generate uniform random values in parallel
895
896
897
  #pragma omp parallel proc_bind(spread)
  {
    std::mt19937 gen_local = *gen;
898
899
900
901
902
    const int thread_ID = omp_get_thread_num();
    const int threads_num = omp_get_max_threads();
    const int chunk_size = (size + threads_num - 1) / threads_num;
    const int idx_min = chunk_size * thread_ID;
    const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast<int>(size));
903
    gen_local.discard(idx_min * rng_calls_per_val);
904
    std::uniform_real_distribution<> dis(-2.0, 1.0);
905
906

    for (int i = idx_min; i < idx_max; ++i) {
907
908
909
      data[i] = static_cast<T>(dis(gen_local));
    }
  }
910
  gen->discard(size * rng_calls_per_val);
911
912
}

913
// Fills the CPU copy of a tensor with uniformly distributed random values,
// assigns a random scale_inv and uploads the tensor via from_cpu().
// The rowwise buffer is filled when the tensor reports rowwise(); otherwise
// the columnwise buffer is filled.
void fillUniform(Tensor *t) {
  const bool use_rowwise = t->rowwise();
  const size_t numel = use_rowwise ? product(t->rowwise_shape())
                                   : product(t->columnwise_shape());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T,
    {
      T *buffer = use_rowwise ? t->rowwise_cpu_dptr<T>() : t->columnwise_cpu_dptr<T>();
      generate_data_uniformly(buffer, numel, &(t->gen()));
    }
  );

  // scale_inv is drawn from the tensor's generator after the data.
  std::uniform_real_distribution<> scale_inv_dist(-2.0, 1.0);
  t->set_scale_inv(scale_inv_dist(t->gen()));
  t->from_cpu();
}

// Fills the rowwise CPU buffer of a tensor according to the compile-time fill
// case, sets scale_inv to 1 and uploads the tensor via from_cpu().
//   zeros   - every element is set to 0
//   uniform - values drawn from [-2, 1), then negated at random
//   other   - values drawn between Quantized_Limits<InputEncoding>::ranges[Case]
//             and ranges[Case + 1], then negated at random
template<typename InputEncoding, InputsFillCase Case>
void fillCase_special(Tensor *t) {
  const size_t numel = product(t->rowwise_shape());

  if constexpr (Case == InputsFillCase::zeros) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
      InputType *out = t->rowwise_cpu_dptr<InputType>();
      std::fill_n(out, numel, static_cast<InputType>(0));
    });
  } else {
    double lower = -2.0;
    double upper = 1.0;
    if constexpr (Case != InputsFillCase::uniform) {
      lower = Quantized_Limits<InputEncoding>::ranges[Case];
      upper = Quantized_Limits<InputEncoding>::ranges[Case + 1];
    }
    std::uniform_real_distribution<> value_dist(lower, upper);
    std::uniform_real_distribution<> sign_dist(-1.0, 1.0);
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
      InputType *out = t->rowwise_cpu_dptr<InputType>();
      for (size_t i = 0; i < numel; ++i) {
        // Draw order is part of the behaviour: sign first, then the value.
        const double sign_draw = sign_dist(t->gen());
        const double drawn = value_dist(t->gen());
        out[i] = static_cast<InputType>((sign_draw < 0.0) ? -drawn : drawn);
      }
    });
  }
  t->set_scale_inv(1.0);
  t->from_cpu();
}

972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
// Runtime dispatcher: forwards to the fillCase_special instantiation that
// matches fill_case. A value outside the listed enumerators does nothing.
template <typename InputEncoding>
void fillCase(Tensor *t, const InputsFillCase fill_case) {
  switch (fill_case) {
    case InputsFillCase::uniform: {
      fillCase_special<InputEncoding, InputsFillCase::uniform>(t);
      return;
    }
    case InputsFillCase::zeros: {
      fillCase_special<InputEncoding, InputsFillCase::zeros>(t);
      return;
    }
    case InputsFillCase::zero_to_minNorm: {
      fillCase_special<InputEncoding, InputsFillCase::zero_to_minNorm>(t);
      return;
    }
    case InputsFillCase::minNorm_to_maxNorm: {
      fillCase_special<InputEncoding, InputsFillCase::minNorm_to_maxNorm>(t);
      return;
    }
    case InputsFillCase::maxNorm_to_inf: {
      fillCase_special<InputEncoding, InputsFillCase::maxNorm_to_inf>(t);
      return;
    }
  }
}

988
989
990
991
992
993
994
// Explicit instantiations of fillCase for every supported input encoding.

// Integer encodings.
template void fillCase<byte>(Tensor *, const InputsFillCase);
template void fillCase<int16>(Tensor *, const InputsFillCase);
template void fillCase<int32>(Tensor *, const InputsFillCase);
template void fillCase<int64>(Tensor *, const InputsFillCase);

// 32-bit and 16-bit floating-point encodings.
template void fillCase<fp32>(Tensor *, const InputsFillCase);
template void fillCase<fp16>(Tensor *, const InputsFillCase);
template void fillCase<bf16>(Tensor *, const InputsFillCase);

// 8-bit floating-point encodings.
template void fillCase<fp8e4m3>(Tensor *, const InputsFillCase);
template void fillCase<fp8e5m2>(Tensor *, const InputsFillCase);

// 4-bit floating point, only when FP4_TYPE_SUPPORTED is enabled.
#if FP4_TYPE_SUPPORTED
template void fillCase<fp4e2m1>(Tensor *, const InputsFillCase);
#endif
1000

1001
1002
// Draws one value from a uniform distribution on [-2, 1) using the tensor's
// own generator, narrows it to float and stores it as the tensor's scale.
void setRandomScale(Tensor *t) {
  std::uniform_real_distribution<> scale_dist(-2.0, 1.0);
  const float drawn_scale = static_cast<float>(scale_dist(t->gen()));
  t->set_scale(drawn_scale);
}

1007
1008
1009
1010
1011
1012
// Draws one value from a uniform distribution on [-2, 1) using the tensor's
// own generator, narrows it to float and stores it as the tensor's scale_inv.
void setRandomScaleInv(Tensor *t) {
  std::uniform_real_distribution<> scale_inv_dist(-2.0, 1.0);
  const float drawn_scale_inv = static_cast<float>(scale_inv_dist(t->gen()));
  t->set_scale_inv(drawn_scale_inv);
}

1013
// Returns true for the 8-bit floating-point types: E4M3, E5M2 and E8M0.
bool isFp8Type(DType type) {
  switch (type) {
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
    case DType::kFloat8E8M0:
      return true;
    default:
      return false;
  }
}

1017
1018
1019
1020
// Returns true only for the 4-bit floating-point type E2M1.
bool isFp4Type(DType type) {
  const bool is_fp4 = (DType::kFloat4E2M1 == type);
  return is_fp4;
}

1021
1022
1023
1024
int32_t getDeviceComputeCapability() {
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, 0);
  return 10 * deviceProp.major + deviceProp.minor;
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
}

// Returns the product of all dimensions except the last one.
// Shapes with zero or one dimension yield 1.
size_t first_dimension(const std::vector<size_t> &shape) {
  const size_t rank = shape.size();
  if (rank <= 1) {
    return 1;
  }
  return product(shape, 0, rank - 1);
}

// Returns the size of the last dimension, or 1 for an empty shape.
size_t last_dimension(const std::vector<size_t> &shape) {
  return shape.empty() ? static_cast<size_t>(1) : shape.back();
}

std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
                                            const size_t cols,
                                            const size_t block_size_rows,
                                            const size_t block_size_cols) {
1042
1043
    const bool is_rowwise = (block_size_rows == 1)
                            && ((block_size_cols == 32) || (block_size_cols == 16));
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057

    const size_t alignment_Y = is_rowwise
                               ? scale_tensor_alignment_Y_rowwise
                               : scale_tensor_alignment_Y_colwise;
    const size_t alignment_X = is_rowwise
                               ? scale_tensor_alignment_X_rowwise
                               : scale_tensor_alignment_X_colwise;

    const size_t unpadded_blocks_Y = divide_round_up(rows, block_size_rows);
    const size_t unpadded_blocks_X = divide_round_up(cols, block_size_cols);

    const size_t blocks_Y = round_up_to_nearest_multiple(unpadded_blocks_Y, alignment_Y);
    const size_t blocks_X = round_up_to_nearest_multiple(unpadded_blocks_X, alignment_X);
    return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X};
1058
1059
}

Przemek Tredak's avatar
Przemek Tredak committed
1060
}  // namespace test