/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "test_common.h"

#include <algorithm>
#include <memory>
#include <random>
#include <iostream>
#include <cassert>
#include <cmath>
#include <string>
#include <type_traits>
#include <unordered_map>

#include <gtest/gtest.h>
#include <omp.h>

#include <transformer_engine/transformer_engine.h>
#include "util/logging.h"

namespace test {

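// Derives a deterministic seed from the current gtest test name plus the
// tensor name, so every tensor gets a reproducible yet distinct random stream.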
size_t create_seed_from_tensor_name(const std::string& tensor_name) {
  auto full_name = std::string(testing::UnitTest::GetInstance()->current_test_info()->name()) +
                   "/" + tensor_name;
  return std::hash<std::string>{}(full_name);
}

std::vector<DType> all_fp_types = {DType::kFloat32,
                                   DType::kFloat16,
                                   DType::kBFloat16,
                                   DType::kFloat8E5M2,
                                   DType::kFloat8E4M3};

bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
  if (s1.ndim != s2.ndim) return false;

  for (size_t i = 0; i < s1.ndim; ++i) {
    if (s1.data[i] != s2.data[i]) return false;
  }

  return true;
}

size_t typeToNumBits(DType type) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
  {
      return TypeInfo<T>::size;
  });
}

const std::string &typeName(DType type) {
  static const std::unordered_map<DType, std::string> name_map = {
    {DType::kByte, "byte"},
    {DType::kInt32, "int32"},
    {DType::kInt64, "int64"},
    {DType::kFloat32, "float32"},
    {DType::kFloat16, "float16"},
    {DType::kBFloat16, "bfloat16"},
    {DType::kFloat8E4M3, "float8e4m3"},
    {DType::kFloat8E5M2, "float8e5m2"},
    {DType::kFloat8E8M0, "float8e8m0"},
    {DType::kFloat4E2M1, "float4e2m1"}};
  return name_map.at(type);
}

const std::string& caseName(InputsFillCase type) {
  static const std::unordered_map<InputsFillCase, std::string> name_map = {
    {InputsFillCase::uniform, "uniform"},
    {InputsFillCase::zeros, "zeros"},
    {InputsFillCase::zero_to_minNorm, "zero_to_minNorm"},
    {InputsFillCase::minNorm_to_maxNorm, "minNorm_to_maxNorm"},
    {InputsFillCase::maxNorm_to_inf, "maxNorm_to_inf"}};
  return name_map.at(type);
}

size_t product(const NVTEShape &shape, size_t begin, size_t end) {
    size_t ret = 1;
    NVTE_CHECK(end <= shape.ndim);
    for (size_t i = begin; i < end; ++i) {
      ret *= shape.data[i];
    }
    return ret;
}

size_t product(const NVTEShape &shape) {
  return product(shape, 0, shape.ndim);
}

size_t product(const std::vector<size_t> shape, size_t begin, size_t end) {
    size_t ret = 1;
    NVTE_CHECK(end <= shape.size());
    for (size_t i = begin; i < end; ++i) {
      ret *= shape[i];
    }
    return ret;
}

size_t product(const std::vector<size_t>& shape) {
  return product(shape, 0, shape.size());
}

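// Integer ceil-division and round-up-to-multiple helpers,
// e.g. DIVUP(10, 4) == 3 and DIVUP_TO_MULTIPLE(10, 4) == 12.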
size_t DIVUP(const size_t &x, const size_t &y){
  return (x + y - 1) / y;
}

size_t DIVUP_TO_MULTIPLE(const size_t &x, const size_t &y){
  return DIVUP(x, y) * y;
}

struct scale_inv_meta {
  std::vector<size_t> shape;
  DType type;
  size_t type_size_bits;
  size_t bytes() const noexcept {
    return (product(shape) * type_size_bits) / 8;
  }
};

size_t bytes(const NVTEShape& shape, const DType type) {
  return (product(shape) * typeToNumBits(type)) / 8;
}

NVTEShape convertShape(const std::vector<size_t>& s) {
  return nvte_make_shape(s.data(), s.size());
}

std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
                                                     const NVTEScalingMode scaling_mode) {
  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
    scale_inv_meta ret;
    ret.shape = {1};
    ret.type = DType::kFloat32;
    ret.type_size_bits = typeToNumBits(DType::kFloat32);
    return {ret, ret};
  }
  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    scale_inv_meta ret_rowwise, ret_colwise;

    const size_t block_size_X_rowwise = 32;
    size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
    ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};

    const size_t block_size_Y_colwise = 32;
    size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
    size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
    ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};
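
    // Worked example (assuming the alignment constants from test_common.h are
    // rowwise Y/X = 128/4 and columnwise Y/X = 4/128): a 256x1024 input has
    // one E8M0 scale per 32 elements, so the rowwise scales pad to {256, 32}
    // and the columnwise scales to {8, 1024}.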

    ret_rowwise.type = DType::kFloat8E8M0;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
    ret_colwise.type = DType::kFloat8E8M0;
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);

    return {ret_rowwise, ret_colwise};
  }
  if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    NVTE_CHECK(last_dim % 32 == 0);
    NVTE_CHECK(first_dim % 32 == 0);

    scale_inv_meta ret_rowwise, ret_colwise;

    size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise);
    ret_rowwise.shape = {scale_dim_Y, scale_dim_X};

    size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise);
    ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t};
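
    // NVFP4 uses 16-element blocks with FP8 E4M3 scales, and the columnwise
    // scales are laid out for the transposed tensor. Under the same alignment
    // assumptions as in the MXFP8 branch, a 256x1024 input gives rowwise
    // scales of {256, 64} and columnwise scales of {1024, 16}.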

    ret_rowwise.type = DType::kFloat8E4M3;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
    ret_colwise.type = DType::kFloat8E4M3;
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);

    return {ret_rowwise, ret_colwise};
  }
  if (scaling_mode == NVTE_BLOCK_SCALING_2D) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    scale_inv_meta ret_rowwise, ret_colwise;

    {
      size_t scale_dim_0 = DIVUP(first_dim, 128lu);
      size_t scale_dim_1 = DIVUP(DIVUP(last_dim, 128lu), 4) * 4;
      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
    }
    {
      size_t scale_dim_0 = DIVUP(last_dim, 128lu);
      size_t scale_dim_1 = DIVUP(DIVUP(first_dim, 128lu), 4) * 4;
      ret_colwise.shape = {scale_dim_0, scale_dim_1};
    }
    ret_rowwise.type = DType::kFloat32;
    ret_colwise.type = DType::kFloat32;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
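    // One FP32 scale per 128x128 block, with the inner scale dimension padded
    // to a multiple of 4: e.g. a 256x1024 input gives rowwise scales of
    // {2, 8} and columnwise scales of {8, 4}.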

    return {ret_rowwise, ret_colwise};
  }
  if (scaling_mode == NVTE_BLOCK_SCALING_1D) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);
    scale_inv_meta ret_rowwise, ret_colwise;

    {
      size_t scale_dim_0 = DIVUP(last_dim, 128lu);
      size_t scale_dim_1 = DIVUP(first_dim, 4) * 4;
      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
    }
    {
      size_t scale_dim_0 = DIVUP(first_dim, 128lu);
      size_t scale_dim_1 = DIVUP(last_dim, 4) * 4;
      ret_colwise.shape = {scale_dim_0, scale_dim_1};
    }
    ret_rowwise.type = DType::kFloat32;
    ret_colwise.type = DType::kFloat32;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
    return {ret_rowwise, ret_colwise};
  }

  NVTE_ERROR("Invalid scaling mode!");
}

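// Allocates device buffers (plus zero-filled host mirrors) for the requested
// rowwise/columnwise usages, along with the per-mode quantization metadata:
// amax/scale and a single shared scale-inv for delayed tensor scaling, or
// scale-inv tensors shaped by get_scales() for the block-scaled modes.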
Tensor::Tensor(const std::string& name,
               const NVTEShape &shape, const DType type,
               const bool rowwise, const bool columnwise,
               const NVTEScalingMode &scaling_mode) {
  name_ = name;
  const size_t seed = create_seed_from_tensor_name(name);
  gen_.seed(seed);
  rowwise_ = rowwise;
  columnwise_ = columnwise;
  size_t total_size = bytes(shape, type);
  void *dptr_rowwise = nullptr;
  void *dptr_columnwise = nullptr;
  cpu_data_rowwise_ = nullptr;
  cpu_data_columnwise_ = nullptr;
  amax_cpu_data_ = nullptr;
  scale_cpu_data_ = nullptr;
  rowwise_scale_inv_cpu_data_ = nullptr;
  columnwise_scale_inv_cpu_data_ = nullptr;
  float *amax = nullptr, *scale = nullptr;
  float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr;
  if (columnwise) {
    NVTE_CHECK(shape.ndim >= 2);
  }
  std::vector<size_t> normalized_shape_v = {product(shape, 0, shape.ndim - 1),
                                            shape.data[shape.ndim - 1]};
  NVTEShape normalized_shape = convertShape(normalized_shape_v);
  NVTEShape columnwise_shape = {};

  std::vector<size_t> columnwise_shape_vec;
  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING
      || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // Transpose when tensor scaling
    columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
    for (size_t i = 0; i < shape.ndim - 1; ++i) {
      columnwise_shape_vec.emplace_back(shape.data[i]);
    }
  } else {
    // Same shape for MX and NVFP4
    for (size_t i = 0; i < shape.ndim; ++i) {
      columnwise_shape_vec.emplace_back(shape.data[i]);
    }
  }

  if (columnwise) {
    columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(), columnwise_shape_vec.size());
  }

  tensor_ = TensorWrapper(scaling_mode);

  if (total_size != 0) {
    if (rowwise) {
      cudaMalloc((void**)&dptr_rowwise, total_size);  // NOLINT(*)
      cudaMemset(dptr_rowwise, 0, total_size);
      cpu_data_rowwise_ = std::make_unique<unsigned char[]>(total_size);
      std::fill_n(cpu_data_rowwise_.get(), total_size, 0);
    }
    if (columnwise) {
      cudaMalloc((void**)&dptr_columnwise, total_size);  // NOLINT(*)
      cudaMemset(dptr_columnwise, 0, total_size);
      cpu_data_columnwise_ = std::make_unique<unsigned char[]>(total_size);
      std::fill_n(cpu_data_columnwise_.get(), total_size, 0);
    }
  }

  const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
  const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
  tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape);
  tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape);

  if (isFp8Type(type) || isFp4Type(type)) {
    if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
      cudaMalloc((void**)&amax, sizeof(float));  // NOLINT(*)
      cudaMemset(amax, 0, sizeof(float));
      cudaMalloc((void**)&scale, sizeof(float));  // NOLINT(*)
      cudaMemset(scale, 0, sizeof(float));
      amax_cpu_data_ = std::make_shared<float>(0);
      scale_cpu_data_ = std::make_shared<float>(0);
      tensor_.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
      tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      cudaMalloc((void**)&rowwise_scale_inv, sizeof(float));  // NOLINT(*)
      if (rowwise) {
        tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
                                      std::vector<size_t>{1});
        rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
        std::fill_n(rowwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
      }
      if (columnwise) {
        tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
                                         std::vector<size_t>{1});
        columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
        std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
      }
    } else {
      if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
        // Used for NVFP4 second stage scaling
        cudaMalloc((void**)&scale, sizeof(float));  // NOLINT(*)
        cudaMemset(scale, 0, sizeof(float));
        scale_cpu_data_ = std::make_shared<float>(0);
        tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      }
      auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode());
      auto rowwise_scale_size = rowwise_scale_meta.bytes();
      auto columnwise_scale_size = colwise_scale_meta.bytes();
      auto scale_shape = rowwise_scale_meta.shape;
      auto columnwise_scale_shape = colwise_scale_meta.shape;
      if (rowwise) {
        cudaMalloc((void **)&rowwise_scale_inv, rowwise_scale_size);  // NOLINT(*)
        cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size);
        rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(rowwise_scale_size);
        std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0);
        auto scale_dtype = rowwise_scale_meta.type;
        tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape);
      }
      if (columnwise) {
        cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size);  // NOLINT(*)
        cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size);
        columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(columnwise_scale_size);
        std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0);
        auto scale_dtype = colwise_scale_meta.type;
        tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape);
      }
    }
  }
}

void Tensor::to_cpu() const {
  const NVTEShape s = tensor_.shape();
  const size_t size = bytes(s, tensor_.dtype());
  if (rowwise_) {
    cudaMemcpy(cpu_data_rowwise_.get(),
               tensor_.get_rowwise_data().data_ptr,
               size,
               cudaMemcpyDeviceToHost);
  }
  if (columnwise_) {
    const DType colwise_type = tensor_.dtype();
    const size_t colwise_size = bytes(s, colwise_type);
    cudaMemcpy(cpu_data_columnwise_.get(),
               tensor_.get_columnwise_data().data_ptr,
               colwise_size,
               cudaMemcpyDeviceToHost);
  }
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
      if (tensor_.amax() != nullptr) {
        cudaMemcpy(amax_cpu_data_.get(),
                   tensor_.amax(),
                   sizeof(float),
                   cudaMemcpyDeviceToHost);
      }
      cudaMemcpy(scale_cpu_data_.get(),
                 tensor_.scale(),
                 sizeof(float),
                 cudaMemcpyDeviceToHost);
    }
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
    if (rowwise_) {
      auto scale_size = rowwise_scale_meta.bytes();
      cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
                 tensor_.get_rowwise_scale_inv().data_ptr,
                 scale_size,
                 cudaMemcpyDeviceToHost);
    }
    if (columnwise_) {
      auto scale_size = colwise_scale_meta.bytes();
      cudaMemcpy(columnwise_scale_inv_cpu_data_.get(),
                 tensor_.get_columnwise_scale_inv().data_ptr,
                 scale_size,
                 cudaMemcpyDeviceToHost);
    }
  }
}

void Tensor::from_cpu() const {
  const NVTEShape s = tensor_.shape();
  const size_t size = bytes(s, tensor_.dtype());
  if (rowwise_) {
    cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size,
               cudaMemcpyHostToDevice);
  }
  if (columnwise_) {
    cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
               cudaMemcpyHostToDevice);
  }
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
    if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
        || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) {
      if (tensor_.amax() != nullptr) {
        cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
      }
      cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
    }
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
    if (rowwise_) {
      auto scale_size = rowwise_scale_meta.bytes();
      cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
                 rowwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
    }
    if (columnwise_) {
      auto scale_size = colwise_scale_meta.bytes();
      cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr,
                 columnwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
    }
  }
}

void Tensor::set_scale(float scale) {
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
    NVTE_CHECK(scale_cpu_data_);
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
      *scale_cpu_data_ = scale;
      from_cpu();
    }
  }
}

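// Sets the scale-inverse. Modes with a single scale store the given value
// directly; block-scaled modes instead fill each scale entry with a random
// byte in [0, 127], since one float cannot populate a whole scale tensor.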
void Tensor::set_scale_inv(float scale_inv) {
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
    if (rowwise_) {
      NVTE_CHECK(rowwise_scale_inv_cpu_data_);
    }
    if (columnwise_) {
      NVTE_CHECK(columnwise_scale_inv_cpu_data_);
    }

    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode());
    if (rowwise_) {
      auto num_scales = product(rowwise_scale_meta.shape);
      if (num_scales == 1) {
        rowwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
      } else {
        std::uniform_int_distribution<uint8_t> dis(0, 127);
        auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr<uint8_t>();
        for (size_t i = 0; i < num_scales; i++) {
          scale_inv_ptr[i] = dis(gen_);
        }
      }
    }
    if (columnwise_) {
      auto num_scales = product(colwise_scale_meta.shape);
      if (num_scales == 1) {
        columnwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
      } else {
        std::uniform_int_distribution<uint8_t> dis(0, 127);
        auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr<uint8_t>();
        for (size_t i = 0; i < num_scales; i++) {
          scale_inv_ptr[i] = dis(gen_);
        }
      }
    }
    from_cpu();
  }
}

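// Rebuilds the underlying TensorWrapper so that this tensor keeps its own data
// pointers but aliases `other`'s quantization metadata (amax, scale, and
// scale-inverse buffers).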
void Tensor::shareFP8Meta(const Tensor &other) {
  if ((isFp8Type(dtype()) && isFp8Type(other.dtype()))
      || (isFp4Type(dtype()) && isFp4Type(other.dtype()))) {
    auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
    auto my_rowwise_data = tensor_.get_rowwise_data();
    new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
                                my_rowwise_data.shape);
    auto my_columnwise_data = tensor_.get_columnwise_data();
    new_tensor.set_columnwise_data(my_columnwise_data.data_ptr,
                                   static_cast<DType>(my_columnwise_data.dtype),
                                   my_columnwise_data.shape);
    auto other_amax = other.tensor_.get_amax();
    new_tensor.set_amax(other_amax.data_ptr, static_cast<DType>(other_amax.dtype),
                        other_amax.shape);
    auto other_scale = other.tensor_.get_scale();
    new_tensor.set_scale(other_scale.data_ptr, static_cast<DType>(other_scale.dtype),
                         other_scale.shape);
    auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv();
    new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr,
                                     static_cast<DType>(other_row_scale_inv.dtype),
                                     other_row_scale_inv.shape);
    auto other_col_scale_inv = other.tensor_.get_columnwise_scale_inv();
    new_tensor.set_columnwise_scale_inv(other_col_scale_inv.data_ptr,
                                        static_cast<DType>(other_col_scale_inv.dtype),
                                        other_col_scale_inv.shape);
    tensor_ = std::move(new_tensor);
    to_cpu();
  }
}

using std::to_string;

template <typename T>
std::string to_string(const std::vector<T> &v) {
  std::string s = "[";
  for (const auto x : v) {
    s += to_string(x) + ", ";
  }
  if (!v.empty()) {
    s.pop_back();
    s.pop_back();
  }
  return s + "]";
}

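// Converts a flat index into row-major per-dimension coordinates,
// e.g. unravel(5, {2, 3}) == {1, 2} because 5 == 1 * 3 + 2.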
std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
  std::vector<size_t> ret;
  size_t current_i = i;
  for (size_t current = shape.ndim - 1; current > 0; --current) {
    ret.push_back(current_i % shape.data[current]);
    current_i /= shape.data[current];
  }
  ret.push_back(current_i);
  std::reverse(ret.begin(), ret.end());
  return ret;
}

void compareResults_sequential(const std::string &name, const Tensor &test,
                               const void *ref, const bool rowwise,
                               double atol, double rtol, bool if_on_gpus,
                               const size_t tolerable_mismatches_limit) {
  if (if_on_gpus) test.to_cpu();
  const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
  const size_t N = product(shape);
  size_t mismatches_num = 0;
  int first_mismatch_idx = -1;
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
    const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
    const T *ref_data = reinterpret_cast<const T*>(ref);
    for (size_t i = 0; i < N; ++i) {
      double t = static_cast<double>(test_data[i]);
      double r = static_cast<double>(ref_data[i]);
      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
      /* For Float32 the floating point comparison is enough to error out */
      bool assertion = mismatch && test.dtype() == DType::kFloat32;
      if (mismatch && !assertion) {
        /* Check if it is just a failure of round to nearest choosing different
           side of the real value */
        const double mean = (t + r) / 2;
        const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
        const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
        const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
        const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
        assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
      }
      std::string direction = rowwise ? "rowwise" : "columnwise";
      if (assertion) {
        mismatches_num++;
        if (first_mismatch_idx == -1) {
          first_mismatch_idx = i;
        }
      }
      if (mismatches_num > tolerable_mismatches_limit) {
        const double first_mismatch_t = static_cast<double>(test_data[first_mismatch_idx]);
        const double first_mismatch_r = static_cast<double>(ref_data[first_mismatch_idx]);

        GTEST_FAIL() << mismatches_num << " mismatch(es), which is more than the tolerable mismatch limit of "
                     << tolerable_mismatches_limit << "." << std::endl
                     << "Error in tensor " << name << " in "
                     << direction << " direction." << std::endl
                     << "First mismatch at place " << to_string(unravel(first_mismatch_idx, shape))
                     << " (" << std::to_string(first_mismatch_idx) << "): "
                     << first_mismatch_t << " vs " << first_mismatch_r;
      }
    }
  );
}

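// Parallel scan for comparison failures: returns the smallest failing flat
// index (or N if none) via an OpenMP min-reduction, and accumulates the total
// mismatch count into `mismatches`.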
template <typename T>
static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data,
                                  const size_t N, const double atol, const double rtol,
                                  size_t& mismatches) {
  size_t first_mismatch_idx = N;

  #pragma omp parallel reduction(min: first_mismatch_idx) reduction(+: mismatches) proc_bind(spread)
  {
    size_t thread_mismatches = 0;
    #pragma omp for schedule(static)
    for (size_t i = 0; i < N; ++i) {
      double t = static_cast<double>(test_data[i]);
      double r = static_cast<double>(ref_data[i]);

      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
      /* For Float32 the floating point comparison is enough to error out */
      bool assertion = mismatch && (data_type == DType::kFloat32);
      if (mismatch && !assertion) {
        /* Check if it is just a failure of round to nearest choosing different
            side of the real value */
        const double mean = (t + r) / 2;
        const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
        const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
        const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
        const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
        assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
      }
      if (assertion) {
        if (i < first_mismatch_idx) {
          first_mismatch_idx = i;
        }
        thread_mismatches++;
      }
    }
    mismatches += thread_mismatches;
  }
  return first_mismatch_idx;
}

void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref,
                             const bool rowwise, double atol, double rtol, bool if_on_gpus,
                             const size_t tolerable_mismatches_limit) {
  if (if_on_gpus) test.to_cpu();
  const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
  const size_t N = product(shape);
  size_t mismatches = 0;
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
    const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
    const T *ref_data = reinterpret_cast<const T*>(ref);

    const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol, mismatches);
    if ((i != N) && (mismatches > tolerable_mismatches_limit)) {
      const double t = static_cast<double>(test_data[i]);
      const double r = static_cast<double>(ref_data[i]);
      std::string direction = rowwise ? "rowwise" : "columnwise";

      GTEST_FAIL() << mismatches << " mismatch(es), which is more than the tolerable mismatch limit of "
                   << tolerable_mismatches_limit << "." << std::endl
                   << "Error in tensor " << name << " in "
                   << direction << " direction." << std::endl
                   << "Mismatch at place " << to_string(unravel(i, shape))
                   << " (" << std::to_string(i) << "): " << t << " vs " << r;
    }
  );
}

void compareResults(const std::string &name, const Tensor &test, const void *ref,
                    const bool rowwise, double atol, double rtol, bool if_on_gpus,
                    const size_t tolerable_mismatches_limit) {
  constexpr bool sequential = false;
  if constexpr (sequential) {
    compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit);
  } else {
    compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit);
  }
}

void compareResults(const std::string &name, const float test, const float ref,
                    double atol, double rtol) {
  double t = static_cast<double>(test);
  double r = static_cast<double>(ref);
  bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
  ASSERT_FALSE(mismatch) << "Error in " << name << std::endl
                         << "Mismatch: " << t << " vs " << r;
}

void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
                    size_t N, float mismatch_rate_tol) {
  size_t max_mismatches = std::ceil(N * mismatch_rate_tol);
  size_t n_mismatches = 0;
  std::vector<size_t> mismatch_indices;
  for (size_t i = 0; i < N; i++) {
    bool mismatch = test[i] != ref[i];
    if (mismatch) {
      n_mismatches++;
      mismatch_indices.push_back(i);
    }
    if (n_mismatches > max_mismatches) {
      std::cout << "Error in " << name << std::endl;
      for (const auto &index : mismatch_indices)
        std::cout << "Mismatch at (" << index << "): " << static_cast<int>(test[index]) << " vs "
                  << static_cast<int>(ref[index]) << std::endl;
      GTEST_FAIL() << n_mismatches << " mismatch(es), which is more than the mismatch tolerance.";
    }
  }
}

template <typename T>
struct CastToType;

template <>
struct CastToType<uint8_t> {
  using type = int;
};

template <>
struct CastToType<fp8e4m3> {
  using type = float;
};

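// Compares scale-factor tensors block by block: uint8_t (E8M0) scales with an
// absolute integer tolerance, fp8e4m3 scales with the FP8 tolerances plus a
// round-to-nearest tie check. Fails once the mismatch count exceeds
// min(abs_tolerable_mismatches_limit, N * rel_tolerable_mismatches_limit).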
template <typename T>
void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
                             const size_t row_blocks, const size_t col_blocks, const size_t stride,
                             size_t& mismatches_num, const size_t atol,
                             const double abs_tolerable_mismatches_limit,
                             const double rel_tolerable_mismatches_limit)
{
  using UpcastType = typename CastToType<T>::type;
  auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3);

  const size_t N = row_blocks * col_blocks;
  const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit,
                                                     std::floor(N * rel_tolerable_mismatches_limit));
  mismatches_num = 0;
  std::vector<int> mismatch_indices;

  for (int i = 0; i < row_blocks; ++i) {
    for (int j = 0; j < col_blocks; ++j) {
      const int idx = i * stride + j;
      float t, r;

      bool assertion = false;
      if constexpr (std::is_same_v<T, uint8_t>) {
        t = static_cast<float>(test[idx]);
        r = static_cast<float>(ref[idx]);
        assertion = std::abs(t - r) > atol;
      } else {
        t = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&test[idx]));
        r = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&ref[idx]));
        const bool mismatch = (fabs(t - r) > atol_fp8e4m3)
                              && (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3);
        if (mismatch) {
          /* Check if it is just a failure of round to nearest choosing different
            side of the real value */
          const double mean = (t + r) / 2;
          const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
          const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
          const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
          const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
          assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
        }
      }
      if (assertion) {
        mismatches_num++;
        mismatch_indices.push_back(idx);
      }
      if (mismatches_num > tolerable_mismatches_limit) {
        std::cout << "Error in " << name << std::endl;
        for (const int index : mismatch_indices) {
          std::cout << "Mismatch at (" << index << "):"
                    << static_cast<UpcastType>(test[index]) << " vs "
                    << static_cast<UpcastType>(ref[index]) << std::endl;
        }
        GTEST_FAIL() << mismatches_num << " mismatch(es), which is more than the tolerable mismatch limit of "
                     << tolerable_mismatches_limit << ".";
      }
    }
  }
}

// Instantiate templates
template
void compare_scaling_factors<uint8_t>(const std::string &name, const uint8_t *test, const uint8_t *ref,
                                      const size_t row_blocks, const size_t col_blocks, const size_t stride,
                                      size_t& mismatches_num, const size_t atol,
                                      const double abs_tolerable_mismatches_limit,
                                      const double rel_tolerable_mismatches_limit);

template
void compare_scaling_factors<fp8e4m3>(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref,
                                      const size_t row_blocks, const size_t col_blocks, const size_t stride,
                                      size_t& mismatches_num, const size_t atol,
                                      const double abs_tolerable_mismatches_limit,
                                      const double rel_tolerable_mismatches_limit);


std::pair<double, double> getTolerances(const DType type) {
  switch(type) {
    case DType::kFloat32:
      return {1e-6, 5e-6};
    case DType::kFloat16:
      return {1e-5, 1e-3};
    case DType::kBFloat16:
      return {1e-5, 1e-2};
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
    case DType::kFloat8E8M0:
      return {1e-2, 1e-2};
    default:
      NVTE_ERROR("Invalid type!");
  }
  return {0, 0};
}

template <typename T>
void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
  // Check how many RNG calls are required to generate one uniform random value
  int rng_calls_per_val = 0;
  {
    std::mt19937 gen1 = *gen, gen2 = *gen;
    std::uniform_real_distribution<> dis(-2.0, 1.0);
    (void)dis(gen1);
    while (gen2 != gen1) {
      (void)gen2();
      ++rng_calls_per_val;
    }
  }
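
  // Knowing the per-value stride lets each thread below clone the generator
  // and discard() straight to its own chunk, making the parallel fill
  // bit-identical to a sequential one.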

  // Generate uniform random values in parallel
  #pragma omp parallel proc_bind(spread)
  {
    std::mt19937 gen_local = *gen;
    const int thread_ID = omp_get_thread_num();
    const int threads_num = omp_get_max_threads();
    const int chunk_size = (size + threads_num - 1) / threads_num;
    const int idx_min = chunk_size * thread_ID;
    const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast<int>(size));
    gen_local.discard(idx_min * rng_calls_per_val);
    std::uniform_real_distribution<> dis(-2.0, 1.0);

    for (int i = idx_min; i < idx_max; ++i) {
      data[i] = static_cast<T>(dis(gen_local));
    }
  }
  // Advance the shared generator past the values consumed above.
  gen->discard(size * rng_calls_per_val);
}

void fillUniform(Tensor *t) {
  if (t->rowwise()) {
    const size_t size = product(t->rowwise_shape());
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T,
      {
        T *data = t->rowwise_cpu_dptr<T>();
        generate_data_uniformly(data, size, &(t->gen()));
      }
    );
  } else {
    const size_t size = product(t->columnwise_shape());
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T,
      {
        T *data = t->columnwise_cpu_dptr<T>();
        generate_data_uniformly(data, size, &(t->gen()));
      }
    );
  }

  std::uniform_real_distribution<> dis(-2.0, 1.0);
  t->set_scale_inv(dis(t->gen()));
  t->from_cpu();
}

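// Fills the tensor with magnitudes drawn from the interval selected by `Case`
// (presumably the [zero, minNorm], [minNorm, maxNorm], or [maxNorm, inf)
// bracket of InputEncoding, per Quantized_Limits in test_common.h) with
// random signs; the `zeros` case writes all zeros.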
template<typename InputEncoding, InputsFillCase Case>
void fillCase_special(Tensor *t) {
  const size_t size = product(t->rowwise_shape());

  if constexpr (Case == InputsFillCase::zeros) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
      InputType *data = t->rowwise_cpu_dptr<InputType>();
      for (size_t i = 0; i < size; ++i) {
        data[i] = static_cast<InputType>(0);
      }
    });
  } else {
    double minAbs = -2.0;
    double maxAbs = 1.0;
    if constexpr (Case != InputsFillCase::uniform) {
      minAbs = Quantized_Limits<InputEncoding>::ranges[Case];
      maxAbs = Quantized_Limits<InputEncoding>::ranges[Case + 1];
    }
    std::uniform_real_distribution<> dis(minAbs, maxAbs);
    std::uniform_real_distribution<> dis_sign(-1.0, 1.0);
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
      InputType *data = t->rowwise_cpu_dptr<InputType>();
      for (size_t idx = 0; idx < size; ++idx) {
        const bool is_negative = (dis_sign(t->gen()) < 0.0);
        double val = dis(t->gen());
        if (is_negative) {
          val = -val;
        }
        data[idx] = static_cast<InputType>(val);
      }
    });
  }
  t->set_scale_inv(1.0);
  t->from_cpu();
}

template <typename InputEncoding>
void fillCase(Tensor *t, const InputsFillCase fill_case) {
  switch (fill_case) {
    case InputsFillCase::uniform:
        fillCase_special<InputEncoding, InputsFillCase::uniform>(t); break;
    case InputsFillCase::zeros:
        fillCase_special<InputEncoding, InputsFillCase::zeros>(t); break;
    case InputsFillCase::zero_to_minNorm:
        fillCase_special<InputEncoding, InputsFillCase::zero_to_minNorm>(t); break;
    case InputsFillCase::minNorm_to_maxNorm:
        fillCase_special<InputEncoding, InputsFillCase::minNorm_to_maxNorm>(t); break;
    case InputsFillCase::maxNorm_to_inf:
        fillCase_special<InputEncoding, InputsFillCase::maxNorm_to_inf>(t); break;
  }
}

template void fillCase<byte>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<int16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<int32>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<int64>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp32>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<bf16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp8e4m3>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp8e5m2>(Tensor *t, const InputsFillCase fill_case);
#if FP4_TYPE_SUPPORTED
template void fillCase<fp4e2m1>(Tensor *t, const InputsFillCase fill_case);
#endif

void setRandomScale(Tensor *t) {
  std::uniform_real_distribution<> dis(-2.0, 1.0);
  const float scale = dis(t->gen());
  t->set_scale(scale);
}

void setRandomScaleInv(Tensor *t) {
  std::uniform_real_distribution<> dis(-2.0, 1.0);
  const float scale_inv = dis(t->gen());
  t->set_scale_inv(scale_inv);
}

bool isFp8Type(DType type) {
  return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0;
}

bool isFp4Type(DType type) {
  return type == DType::kFloat4E2M1;
}

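// Compute capability of device 0 as major * 10 + minor, e.g. 90 for SM 9.0.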
int32_t getDeviceComputeCapability() {
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, 0);
  return 10 * deviceProp.major + deviceProp.minor;
}

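// Shapes are treated as 2D for scaling: first_dimension() collapses all
// leading dims (returning 1 for 0-D and 1-D shapes) and last_dimension() is
// the trailing dim.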
size_t first_dimension(const std::vector<size_t> &shape) {
  if (shape.size() == 0) return 1;
  if (shape.size() == 1) return 1;
  return product(shape, 0, shape.size() - 1);
}

size_t last_dimension(const std::vector<size_t> &shape) {
  if (shape.size() == 0) return 1;
  return shape[shape.size() - 1];
}

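// Returns {unpadded_rows, unpadded_cols, padded_rows, padded_cols} of the
// scale tensor for the given data shape and scaling block size, where the
// padded counts are rounded up to the rowwise or columnwise alignment.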
std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
                                            const size_t cols,
                                            const size_t block_size_rows,
                                            const size_t block_size_cols) {
    const bool is_rowwise = (block_size_rows == 1)
                            && ((block_size_cols == 32) || (block_size_cols == 16));

    const size_t alignment_Y = is_rowwise
                               ? scale_tensor_alignment_Y_rowwise
                               : scale_tensor_alignment_Y_colwise;
    const size_t alignment_X = is_rowwise
                               ? scale_tensor_alignment_X_rowwise
                               : scale_tensor_alignment_X_colwise;

    const size_t unpadded_blocks_Y = divide_round_up(rows, block_size_rows);
    const size_t unpadded_blocks_X = divide_round_up(cols, block_size_cols);

    const size_t blocks_Y = round_up_to_nearest_multiple(unpadded_blocks_Y, alignment_Y);
    const size_t blocks_X = round_up_to_nearest_multiple(unpadded_blocks_X, alignment_X);
    return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X};
}

}  // namespace test