/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/


#include "test_common.h"
Tim Moon's avatar
Tim Moon committed
9

Przemek Tredak's avatar
Przemek Tredak committed
10
11
12
#include <algorithm>
#include <memory>
#include <random>
13
#include <iostream>
14
15
16
#include <cassert>
#include <cmath>
#include <string>
Przemek Tredak's avatar
Przemek Tredak committed
17

Tim Moon's avatar
Tim Moon committed
18
#include <gtest/gtest.h>
19
#include <omp.h>
Tim Moon's avatar
Tim Moon committed
20
21
22
23

#include <transformer_engine/transformer_engine.h>
#include "util/logging.h"

maxiao3's avatar
maxiao3 committed
24
25
// Evaluates to true when CUDA_VERSION is at least 12080; code that refers to
// DType::kFloat4E2M1 in this file is compiled only under this guard.
#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)

Przemek Tredak's avatar
Przemek Tredak committed
26
27
namespace test {

28
29
30
31
32
33
size_t create_seed_from_tensor_name(const std::string& tensor_name) {
  auto full_name = std::string(testing::UnitTest::GetInstance()->current_test_info()->name()) +
                   "/" + tensor_name;
  return std::hash<std::string>{}(full_name);
}

Przemek Tredak's avatar
Przemek Tredak committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
// List of floating-point DType values (FP32, FP16, BF16 and the two FP8
// formats E5M2 / E4M3) exposed as a mutable namespace-scope vector.
std::vector<DType> all_fp_types = {DType::kFloat32,
                                   DType::kFloat16,
                                   DType::kBFloat16,
                                   DType::kFloat8E5M2,
                                   DType::kFloat8E4M3};

// Two shapes compare equal when they have the same number of dimensions and
// every corresponding extent matches.
bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
  const bool same_rank = (s1.ndim == s2.ndim);
  // Only the first `ndim` entries of each shape are meaningful.
  return same_rank && std::equal(s1.data, s1.data + s1.ndim, s2.data);
}

50
// Returns TypeInfo<T>::size for the C++ type T that the type-switch macro
// selects for `type`. bytes() in this file divides the result by 8, so the
// value is treated as a width in bits.
size_t typeToNumBits(DType type) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
  {
      return TypeInfo<T>::size;
  });
}

// Returns the human-readable name of a DType.
// The lookup table is built once on first use; an unknown DType makes
// std::unordered_map::at throw std::out_of_range.
const std::string &typeName(DType type) {
  static const std::unordered_map<DType, std::string> name_map = {
    {DType::kByte, "byte"},
    {DType::kInt32, "int32"},
    {DType::kInt64, "int64"},
    {DType::kFloat32, "float32"},
    {DType::kFloat16, "float16"},
    {DType::kBFloat16, "bfloat16"},
    {DType::kFloat8E4M3, "float8e4m3"},
    {DType::kFloat8E5M2, "float8e5m2"},
    {DType::kFloat8E8M0, "float8e8m0"},
#if FP4_TYPE_SUPPORTED
    // The FP4 entry exists only when the toolkit provides the type.
    {DType::kFloat4E2M1, "float4e2m1"},
#endif
  };
  return name_map.at(type);
}

76
77
78
79
80
81
82
83
84
85
86
// Maps an InputsFillCase value to its textual name.
// Throws std::out_of_range (from unordered_map::at) for a value that is not listed.
const std::string& caseName(InputsFillCase type) {
  using Case = InputsFillCase;
  static const std::unordered_map<InputsFillCase, std::string> kCaseNames = {
    {Case::uniform,            "uniform"},
    {Case::zeros,              "zeros"},
    {Case::zero_to_minNorm,    "zero_to_minNorm"},
    {Case::minNorm_to_maxNorm, "minNorm_to_maxNorm"},
    {Case::maxNorm_to_inf,     "maxNorm_to_inf"},
  };
  return kCaseNames.at(type);
}

size_t product(const NVTEShape &shape, size_t begin, size_t end) {
Przemek Tredak's avatar
Przemek Tredak committed
87
    size_t ret = 1;
88
89
    NVTE_CHECK(end <= shape.ndim);
    for (size_t i = begin; i < end; ++i) {
Przemek Tredak's avatar
Przemek Tredak committed
90
91
92
93
      ret *= shape.data[i];
    }
    return ret;
}
94

95
96
97
size_t product(const NVTEShape &shape) {
  return product(shape, 0, shape.ndim);
}
98

99
100
101
102
103
104
105
106
// Multiplies the entries shape[begin] .. shape[end - 1].
// `end` must not exceed shape.size() (enforced with NVTE_CHECK). An empty
// range (begin >= end) yields 1.
// NOTE(review): the vector is taken by value; presumably this mirrors the
// declaration in the header, so the signature is left untouched -- confirm.
size_t product(const std::vector<size_t> shape, size_t begin, size_t end) {
  NVTE_CHECK(end <= shape.size());
  size_t result = 1;
  size_t idx = begin;
  while (idx < end) {
    result *= shape[idx];
    ++idx;
  }
  return result;
}
Przemek Tredak's avatar
Przemek Tredak committed
107

108
109
110
111
112
113
114
115
// Product of every entry of `shape`; 1 for an empty vector.
size_t product(const std::vector<size_t>& shape) {
  const size_t count = shape.size();
  return product(shape, 0, count);
}

// Integer division of x by y, rounded up. y must be non-zero.
size_t DIVUP(const size_t &x, const size_t &y){
  const size_t bumped = x + (y - 1);
  return bumped / y;
}

116
117
118
119
// Rounds x up to the nearest multiple of y. y must be non-zero.
size_t DIVUP_TO_MULTIPLE(const size_t &x, const size_t &y){
  const size_t chunks = DIVUP(x, y);
  return chunks * y;
}

120
121
122
// Describes one scale-inverse buffer: its shape, element type and the width
// of one element in bits.
struct scale_inv_meta {
  std::vector<size_t> shape;
  DType type;
  size_t type_size_bits;
  // Size of the buffer in bytes: element count times bits per element,
  // divided by 8 (integer division).
  size_t bytes() const noexcept {
    return (product(shape) * type_size_bits) / 8;
  }
};

129
130
131
132
// Number of bytes needed to store a tensor of the given shape and element
// type: element count times bits per element, divided by 8 (integer division).
size_t bytes(const NVTEShape& shape, const DType type) {
  const size_t num_elements = product(shape);
  const size_t bits_per_element = typeToNumBits(type);
  return (num_elements * bits_per_element) / 8;
}

133
134
// Builds an NVTEShape from a vector of extents via nvte_make_shape.
NVTEShape convertShape(const std::vector<size_t>& s) {
  return nvte_make_shape(s.data(), s.size());
}

std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
                                                     const NVTEScalingMode scaling_mode) {
  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
    scale_inv_meta ret;
    ret.shape = {1};
    ret.type = DType::kFloat32;
143
    ret.type_size_bits = typeToNumBits(DType::kFloat32);
144
145
146
147
148
149
150
151
152
153
154
155
    return {ret, ret};
  }
  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    scale_inv_meta ret_rowwise, ret_colwise;

156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
    const size_t block_size_X_rowwise = 32;
    size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
    ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};

    const size_t block_size_Y_colwise = 32;
    size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
    size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
    ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};

    ret_rowwise.type = DType::kFloat8E8M0;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
    ret_colwise.type = DType::kFloat8E8M0;
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);

    return {ret_rowwise, ret_colwise};
  }
  if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
177
    }
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    NVTE_CHECK(last_dim % 32 == 0);
    NVTE_CHECK(first_dim % 32 == 0);

    scale_inv_meta ret_rowwise, ret_colwise;

    size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise);
    ret_rowwise.shape = {scale_dim_Y, scale_dim_X};

    size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise);
    ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t};

    ret_rowwise.type = DType::kFloat8E4M3;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
    ret_colwise.type = DType::kFloat8E4M3;
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);

    return {ret_rowwise, ret_colwise};
  }
  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
Przemek Tredak's avatar
Przemek Tredak committed
205
    }
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    scale_inv_meta ret_rowwise, ret_colwise;

    const size_t block_size_X_rowwise = 32;
    size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
    ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};

    const size_t block_size_Y_colwise = 32;
    size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
    size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
    ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};

221
222
    ret_rowwise.type = DType::kFloat8E8M0;
    ret_colwise.type = DType::kFloat8E8M0;
223
224
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
225
226
227

    return {ret_rowwise, ret_colwise};
  }
228
229
230
231
232
233
234
235
236
237
238
  if (scaling_mode == NVTE_BLOCK_SCALING_2D) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    scale_inv_meta ret_rowwise, ret_colwise;

    {
239
240
      auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len()));
      auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len())), 4) * 4;
241
242
243
      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
    }
    {
244
245
      auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len()));
      auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len())), 4) * 4;
246
247
248
249
      ret_colwise.shape = {scale_dim_0, scale_dim_1};
    }
    ret_rowwise.type = DType::kFloat32;
    ret_colwise.type = DType::kFloat32;
250
251
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
252
253
254
255
256
257
258
259
260
261
262
263
264

    return {ret_rowwise, ret_colwise};
  }
  if (scaling_mode == NVTE_BLOCK_SCALING_1D) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);
    scale_inv_meta ret_rowwise, ret_colwise;

    {
265
      auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len()));
266
267
268
269
      auto scale_dim_1 = DIVUP(first_dim, 4) * 4;
      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
    }
    {
270
      auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len()));
271
272
273
274
275
      auto scale_dim_1 = DIVUP(last_dim, 4) * 4;
      ret_colwise.shape = {scale_dim_0, scale_dim_1};
    }
    ret_rowwise.type = DType::kFloat32;
    ret_colwise.type = DType::kFloat32;
276
277
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
278
279
    return {ret_rowwise, ret_colwise};
  }
280
281
282
283
284
285
286

  NVTE_ERROR("Invalid scaling mode!");
}

// Creates a test tensor with device buffers and zero-initialized host mirrors.
//
//   name         - tensor name; combined with the running test's name to seed gen_.
//   shape        - logical tensor shape.
//   type         - element type of the data buffers.
//   rowwise      - allocate and register the row-wise data (and scale) buffers.
//   columnwise   - allocate and register the column-wise data (and scale)
//                  buffers; requires shape.ndim >= 2.
//   scaling_mode - scaling mode handed to the wrapped tensor_.
//
// For FP8/FP4 element types the amax / scale / scale-inverse buffers required
// by the scaling mode are allocated as well.
// NOTE(review): return codes of the cudaMalloc/cudaMemset calls are not
// checked here, matching the rest of this file.
Tensor::Tensor(const std::string& name,
               const NVTEShape &shape, const DType type,
               const bool rowwise, const bool columnwise,
               const NVTEScalingMode &scaling_mode)
  : tensor_(scaling_mode), rowwise_{rowwise}, columnwise_{columnwise}, name_{name} {
  // Initialize RNG
  const size_t seed = create_seed_from_tensor_name(name);
  gen_.seed(seed);

  // Make sure shape is valid
  if (columnwise) {
    NVTE_CHECK(shape.ndim >= 2);
  }

  // Shape after flattening to 2D
  NVTEShape flattened_shape;
  {
    std::vector<size_t> flattened_shape_vec;
    if (shape.ndim > 0) {
      flattened_shape_vec.push_back(product(shape, 0, shape.ndim - 1));
      flattened_shape_vec.push_back(shape.data[shape.ndim - 1]);
    } else {
      // A 0-dimensional tensor is treated as 1x1.
      flattened_shape_vec.resize(2, 1);
    }
    flattened_shape = convertShape(flattened_shape_vec);
  }

  // Allocate and initialize data
  void *dptr_rowwise = nullptr, *dptr_columnwise = nullptr;
  const size_t total_size = bytes(shape, type);
  if (total_size != 0) {
    if (rowwise) {
      cudaMalloc((void**)&dptr_rowwise, total_size);  // NOLINT(*)
      cudaMemset(dptr_rowwise, 0, total_size);
      cpu_data_rowwise_ = std::make_unique<unsigned char[]>(total_size);
      std::fill_n(cpu_data_rowwise_.get(), total_size, 0);
    }
    if (columnwise) {
      cudaMalloc((void**)&dptr_columnwise, total_size);  // NOLINT(*)
      cudaMemset(dptr_columnwise, 0, total_size);
      cpu_data_columnwise_ = std::make_unique<unsigned char[]>(total_size);
      std::fill_n(cpu_data_columnwise_.get(), total_size, 0);
    }
  }

  // Set tensor row-wise data
  if (rowwise) {
#if FP4_TYPE_SUPPORTED
    // NVFP4 scaling always registers the data buffer as FP4.
    const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
    tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape);
#else
    tensor_.set_rowwise_data(dptr_rowwise, type, shape);
#endif
  }

  // Set tensor column-wise data
  if (columnwise) {
    // Determine shape of column-wise data
    std::vector<size_t> columnwise_shape_vec;
    switch (scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING:
    case NVTE_BLOCK_SCALING_1D:
    case NVTE_BLOCK_SCALING_2D: {
      // Column-wise data shape is transposed
      if (shape.ndim > 0) {
        columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
        for (size_t i = 0; i < shape.ndim - 1; ++i) {
          columnwise_shape_vec.emplace_back(shape.data[i]);
        }
      }
      break;
    }
    case NVTE_MXFP8_1D_SCALING:
    case NVTE_NVFP4_1D_SCALING: {
      // Column-wise data matches shape
      for (size_t i = 0; i < shape.ndim; ++i) {
        columnwise_shape_vec.emplace_back(shape.data[i]);
      }
      break;
    }
    default:
      NVTE_ERROR("Unrecognized scaling mode (", (size_t)scaling_mode, ").");
    }
    const auto columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(),
                                                  columnwise_shape_vec.size());

    // Set column-wise data buffer
#if FP4_TYPE_SUPPORTED
    const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
    tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape);
#else
    tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape);
#endif
  }

  // Configure scales, amaxes, and other tensor buffers
  float *amax = nullptr, *scale = nullptr;
  float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr;
  if (isFp8Type(type) || isFp4Type(type)) {
    if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
      cudaMalloc((void**)&amax, sizeof(float));  // NOLINT(*)
      cudaMemset(amax, 0, sizeof(float));
      cudaMalloc((void**)&scale, sizeof(float));  // NOLINT(*)
      cudaMemset(scale, 0, sizeof(float));
      amax_cpu_data_ = std::make_shared<float>(0);
      scale_cpu_data_ = std::make_shared<float>(0);
      tensor_.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
      tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      // A single device scale-inverse is shared by the row-wise and
      // column-wise views. Zero it so it matches the zeroed host mirrors.
      cudaMalloc((void**)&rowwise_scale_inv, sizeof(float));  // NOLINT(*)
      cudaMemset(rowwise_scale_inv, 0, sizeof(float));
      if (rowwise) {
        tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
                                      std::vector<size_t>{1});
        rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
        std::fill_n(rowwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
      }
      if (columnwise) {
        tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
                                          std::vector<size_t>{1});
        columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
        std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
      }
    } else {
      if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
        // Used for NVFP4 second stage scaling
        cudaMalloc((void**)&scale, sizeof(float));  // NOLINT(*)
        cudaMemset(scale, 0, sizeof(float));
        scale_cpu_data_ = std::make_shared<float>(0);
        tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      }
      auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(flattened_shape, tensor_.scaling_mode());
      auto rowwise_scale_size = rowwise_scale_meta.bytes();
      auto columnwise_scale_size = colwise_scale_meta.bytes();
      auto scale_shape = rowwise_scale_meta.shape;
      auto columnwise_scale_shape = colwise_scale_meta.shape;
      if (rowwise) {
        cudaMalloc((void **)&rowwise_scale_inv, rowwise_scale_size);  // NOLINT(*)
        cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size);
        rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(rowwise_scale_size);
        std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0);
        auto scale_dtype = rowwise_scale_meta.type;
        tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape);
      }
      if (columnwise) {
        cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size);  // NOLINT(*)
        cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size);
        columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(columnwise_scale_size);
        std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0);
        auto scale_dtype = colwise_scale_meta.type;
        tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape);
      }
    }
  }
}

void Tensor::to_cpu() const {
  const NVTEShape s = tensor_.shape();
438
  const size_t size = bytes(s, tensor_.dtype());
439
440
441
442
443
444
445
  if (rowwise_) {
    cudaMemcpy(cpu_data_rowwise_.get(),
               tensor_.get_rowwise_data().data_ptr,
               size,
               cudaMemcpyDeviceToHost);
  }
  if (columnwise_) {
446
447
448
    const DType colwise_type = tensor_.dtype();

    const size_t colwise_size = bytes(s, colwise_type);
449
    cudaMemcpy(cpu_data_columnwise_.get(),
450
451
452
                tensor_.get_columnwise_data().data_ptr,
                colwise_size,
                cudaMemcpyDeviceToHost);
453
  }
454
455
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
    if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) {
456
457
458
459
460
461
      if (tensor_.amax() != nullptr){
        cudaMemcpy(amax_cpu_data_.get(),
                  tensor_.amax(),
                  sizeof(float),
                  cudaMemcpyDeviceToHost);
      }
462
463
464
465
466
      cudaMemcpy(scale_cpu_data_.get(),
                 tensor_.scale(),
                 sizeof(float),
                 cudaMemcpyDeviceToHost);
    }
467
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
468
    if (rowwise_) {
469
      auto scale_size = rowwise_scale_meta.bytes();
470
471
472
473
474
475
      cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
                 tensor_.get_rowwise_scale_inv().data_ptr,
                 scale_size,
                 cudaMemcpyDeviceToHost);
    }
    if (columnwise_) {
476
      auto scale_size = colwise_scale_meta.bytes();
477
478
479
480
481
      cudaMemcpy(columnwise_scale_inv_cpu_data_.get(),
                 tensor_.get_columnwise_scale_inv().data_ptr,
                 scale_size,
                 cudaMemcpyDeviceToHost);
    }
482
  }
Przemek Tredak's avatar
Przemek Tredak committed
483
484
485
486
}

void Tensor::from_cpu() const {
  const NVTEShape s = tensor_.shape();
487
  const size_t size = bytes(s, tensor_.dtype());
488
  if (rowwise_) {
489
490
    cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size,
               cudaMemcpyHostToDevice);
491
492
  }
  if (columnwise_) {
493
494
    cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
               cudaMemcpyHostToDevice);
495
  }
496
497
498
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
    if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
        || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) {
499
      if (tensor_.amax() != nullptr){
500
        cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
501
      }
502
      cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
503
    }
504
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
505
    if (rowwise_) {
506
      auto scale_size = rowwise_scale_meta.bytes();
507
508
509
510
511
      cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
                 rowwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
    }
    if (columnwise_) {
512
      auto scale_size = colwise_scale_meta.bytes();
513
514
515
516
      cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr,
                 columnwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
    }
517
518
519
520
  }
}

void Tensor::set_scale(float scale) {
521
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
522
    NVTE_CHECK(scale_cpu_data_);
523
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
524
525
526
      *scale_cpu_data_ = scale;
      from_cpu();
    }
527
528
529
530
  }
}

void Tensor::set_scale_inv(float scale_inv) {
531
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
532
533
534
535
536
537
    if (rowwise_) {
      NVTE_CHECK(rowwise_scale_inv_cpu_data_);
    }
    if (columnwise_) {
      NVTE_CHECK(columnwise_scale_inv_cpu_data_);
    }
538

539
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode());
540
541
    if (rowwise_) {
      auto num_scales = product(rowwise_scale_meta.shape);
542
      if (num_scales == 1) {
543
        rowwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
544
      } else {
545
        std::uniform_int_distribution<uint8_t> dis(0, 127);
546
547
        auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr<uint8_t>();
        for (size_t i = 0; i < num_scales; i++) {
548
549
550
551
552
553
          scale_inv_ptr[i] = dis(gen_);
        }
      }
    }
    if (columnwise_) {
      auto num_scales = product(colwise_scale_meta.shape);
554
      if (num_scales == 1) {
555
        columnwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
556
      } else {
557
        std::uniform_int_distribution<uint8_t> dis(0, 127);
558
559
        auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr<uint8_t>();
        for (size_t i = 0; i < num_scales; i++) {
560
561
562
563
          scale_inv_ptr[i] = dis(gen_);
        }
      }
    }
564
565
566
567
568
    from_cpu();
  }
}

void Tensor::shareFP8Meta(const Tensor &other) {
569
570
  if ((isFp8Type(dtype()) && isFp8Type(other.dtype()))
      || isFp4Type(dtype()) && isFp4Type(other.dtype())) {
571
572
    auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
    auto my_rowwise_data = tensor_.get_rowwise_data();
573
    new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
574
575
576
577
578
579
                                my_rowwise_data.shape);
    auto my_columnwise_data = tensor_.get_columnwise_data();
    new_tensor.set_columnwise_data(my_columnwise_data.data_ptr,
                                   static_cast<DType>(my_columnwise_data.dtype),
                                   my_columnwise_data.shape);
    auto other_amax = other.tensor_.get_amax();
580
    new_tensor.set_amax(other_amax.data_ptr, static_cast<DType>(other_amax.dtype),
581
582
                        other_amax.shape);
    auto other_scale = other.tensor_.get_scale();
583
    new_tensor.set_scale(other_scale.data_ptr, static_cast<DType>(other_scale.dtype),
584
585
586
587
588
589
590
591
592
593
                         other_scale.shape);
    auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv();
    new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr,
                                     static_cast<DType>(other_row_scale_inv.dtype),
                                     other_row_scale_inv.shape);
    auto other_col_scale_inv = other.tensor_.get_columnwise_scale_inv();
    new_tensor.set_columnwise_scale_inv(other_col_scale_inv.data_ptr,
                                        static_cast<DType>(other_col_scale_inv.dtype),
                                        other_col_scale_inv.shape);
    tensor_ = std::move(new_tensor);
594
595
    to_cpu();
  }
Przemek Tredak's avatar
Przemek Tredak committed
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
}

using std::to_string;

// Formats a vector as "[a, b, c]" using to_string on each element.
// An empty vector is formatted as "[]".
template <typename T>
std::string to_string(const std::vector<T> &v) {
  std::string s = "[";
  bool first = true;
  for (const auto &x : v) {
    // The separator goes before every element except the first, so nothing
    // has to be trimmed afterwards (which was invalid for an empty vector).
    if (!first) {
      s += ", ";
    }
    s += to_string(x);
    first = false;
  }
  return s + "]";
}

std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
  std::vector<size_t> ret;
  size_t current_i = i;
614
  for (size_t current = shape.ndim - 1; current > 0; --current) {
Przemek Tredak's avatar
Przemek Tredak committed
615
616
617
618
619
620
621
622
    ret.push_back(current_i % shape.data[current]);
    current_i /= shape.data[current];
  }
  ret.push_back(current_i);
  std::reverse(ret.begin(), ret.end());
  return ret;
}

623
624
void compareResults_sequential(const std::string &name, const Tensor &test,
                               const void *ref, const bool rowwise,
625
626
                               double atol, double rtol, bool if_on_gpus,
                               const size_t tolerable_mismatches_limit) {
627
628
629
  if (if_on_gpus) test.to_cpu();
  const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
  const size_t N = product(shape);
630
631
  size_t mismatches_num = 0;
  int first_mismatch_idx = -1;
Przemek Tredak's avatar
Przemek Tredak committed
632
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
633
    const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
Przemek Tredak's avatar
Przemek Tredak committed
634
635
    const T *ref_data = reinterpret_cast<const T*>(ref);
    for (size_t i = 0; i < N; ++i) {
yuguo's avatar
yuguo committed
636
#ifndef __HIP_PLATFORM_AMD__
Przemek Tredak's avatar
Przemek Tredak committed
637
638
      double t = static_cast<double>(test_data[i]);
      double r = static_cast<double>(ref_data[i]);
yuguo's avatar
yuguo committed
639
640
641
642
#else
      double t = static_cast<double>(static_cast<float>(test_data[i]));
      double r = static_cast<double>(static_cast<float>(ref_data[i]));
#endif
Przemek Tredak's avatar
Przemek Tredak committed
643
644
645
646
647
648
649
650
651
      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
      /* For Float32 the floating point comparison is enough to error out */
      bool assertion = mismatch && test.dtype() == DType::kFloat32;
      if (mismatch && !assertion) {
        /* Check if it is just a failure of round to nearest choosing different
           side of the real value */
        const double mean = (t + r) / 2;
        const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
        const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
yuguo's avatar
yuguo committed
652
#ifndef __HIP_PLATFORM_AMD__
Przemek Tredak's avatar
Przemek Tredak committed
653
654
        const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
        const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
yuguo's avatar
yuguo committed
655
656
657
658
659
660
661
662
#else
        const double cast_mean_p = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_p))));
        const double cast_mean_m = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_m))));
#endif

#ifdef __HIP_PLATFORM_AMD__
        assertion = !(cast_mean_m == std::min<double>(t,r) && cast_mean_p == std::max<double>(t,r));
#else
Przemek Tredak's avatar
Przemek Tredak committed
663
        assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
yuguo's avatar
yuguo committed
664
#endif
Przemek Tredak's avatar
Przemek Tredak committed
665
      }
666
      std::string direction = rowwise ? "rowwise" : "columnwise";
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
      if (assertion) {
        mismatches_num++;
        if (first_mismatch_idx == -1) {
          first_mismatch_idx = i;
        }
      }
      if (mismatches_num > tolerable_mismatches_limit) {
        const double first_mismatch_t = static_cast<double>(test_data[first_mismatch_idx]);
        const double first_mismatch_r = static_cast<double>(ref_data[first_mismatch_idx]);

        GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of "
                    << tolerable_mismatches_limit << "." << std::endl
                    << "Error in tensor " << name << " in "
                    << direction << " direction." << std::endl
                     << "First mismatch at place " << to_string(unravel(first_mismatch_idx, shape))
                     << " (" << std::to_string(first_mismatch_idx) << "): "
                     << first_mismatch_t << " vs " << first_mismatch_r;
      }
685
686
687
688
689
690
    }
  );
}

template <typename T>
static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data,
691
692
                                  const size_t N, const double atol, const double rtol,
                                  size_t& mismatches) {
693
694
  int first_mismatch_idx = N;

695
696
697
698
699
  #pragma omp parallel reduction(min: first_mismatch_idx) reduction(+: mismatches) proc_bind(spread)
  {
    size_t thread_mismatches = 0;
    #pragma omp for schedule(static)
    for (size_t i = 0; i < N; ++i) {
yuguo's avatar
yuguo committed
700
#ifndef __HIP_PLATFORM_AMD__
701
702
    double t = static_cast<double>(test_data[i]);
    double r = static_cast<double>(ref_data[i]);
yuguo's avatar
yuguo committed
703
704
705
706
#else
    double t = static_cast<double>(static_cast<float>(test_data[i]));
    double r = static_cast<double>(static_cast<float>(ref_data[i]));
#endif
707
708
709
710
711
712
713
714
715
716

      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
      /* For Float32 the floating point comparison is enough to error out */
      bool assertion = mismatch && (data_type == DType::kFloat32);
      if (mismatch && !assertion) {
        /* Check if it is just a failure of round to nearest choosing different
            side of the real value */
        const double mean = (t + r) / 2;
        const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
        const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
yuguo's avatar
yuguo committed
717
#ifndef __HIP_PLATFORM_AMD__
718
719
      const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
      const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
yuguo's avatar
yuguo committed
720
721
722
723
724
725
726
#else
      const double cast_mean_p = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_p))));
      const double cast_mean_m = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_m))));
#endif
#ifdef __HIP_PLATFORM_AMD__
      assertion = !(cast_mean_m == std::min<double>(t,r) && cast_mean_p == std::max<double>(t,r));
#else
727
      assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
yuguo's avatar
yuguo committed
728
#endif
729
730
731
732
733
734
735
      }
      if (assertion) {
        if (i < first_mismatch_idx) {
          first_mismatch_idx = i;
        }
        thread_mismatches++;
      }
736
    }
737
    mismatches += thread_mismatches;
738
739
740
741
742
  }
  return first_mismatch_idx;
}

void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref,
743
744
                             const bool rowwise, double atol, double rtol, bool if_on_gpus,
                             const size_t tolerable_mismatches_limit) {
745
746
747
  if (if_on_gpus) test.to_cpu();
  const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
  const size_t N = product(shape);
748
  size_t mismatches = 0;
749
750
751
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
    const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
    const T *ref_data = reinterpret_cast<const T*>(ref);
Przemek Tredak's avatar
Przemek Tredak committed
752

753
754
    const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol, mismatches);
    if ((i != N) && (mismatches > tolerable_mismatches_limit)) {
yuguo's avatar
yuguo committed
755
#ifndef __HIP_PLATFORM_AMD__
756
757
      const double t = static_cast<double>(test_data[i]);
      const double r = static_cast<double>(ref_data[i]);
yuguo's avatar
yuguo committed
758
759
760
761
#else
      const double t = static_cast<double>(static_cast<float>(test_data[i]));
      const double r = static_cast<double>(static_cast<float>(ref_data[i]));
#endif
762
      std::string direction = rowwise ? "rowwise" : "columnwise";
763
764
765
766
767
768
769

      GTEST_FAIL() << mismatches << " mismatche(s) which is more than tolerable mismatch limit of "
                   << tolerable_mismatches_limit << "." << std::endl
                   << "Error in tensor " << name << " in "
                   << direction << " direction." << std::endl
                   << "Mismatch at place " << to_string(unravel(i, shape))
                   << " (" << std::to_string(i) << "): " << t << " vs " << r;
Przemek Tredak's avatar
Przemek Tredak committed
770
771
772
773
    }
  );
}

774
void compareResults(const std::string &name, const Tensor &test, const void *ref,
775
776
                    const bool rowwise, double atol, double rtol, bool if_on_gpus,
                    const size_t tolerable_mismatches_limit) {
777
778
  constexpr bool sequential = false;
  if constexpr (sequential) {
779
    compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit);
780
  } else {
781
    compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit);
782
783
784
  }
}

785
786
/**
 * Compares two scalar values and raises a fatal test assertion when they
 * differ by more than both the absolute and the relative tolerance.
 *
 * @param name Label used in the failure message.
 * @param test Value under test.
 * @param ref  Reference value.
 * @param atol Absolute tolerance.
 * @param rtol Relative tolerance (treated as violated whenever ref is 0).
 */
void compareResults(const std::string &name, const float test, const float ref,
                    double atol, double rtol) {
  // Both inputs are already float, so a single widening conversion covers
  // every platform.
  const double t = static_cast<double>(test);
  const double r = static_cast<double>(ref);

  bool mismatch = false;
  if (fabs(t - r) > atol) {
    mismatch = (r == 0) || (fabs((t - r) / r) > rtol);
  }
  ASSERT_FALSE(mismatch) << "Error in " << name << std::endl
                         << "Mismatch: " << t << " vs " << r;
}

800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821

/**
 * Compares two byte arrays for exact equality, tolerating a bounded fraction
 * of mismatching elements.
 *
 * As soon as the number of mismatches exceeds ceil(N * mismatch_rate_tol),
 * all mismatches found so far are printed to stdout and the current test is
 * failed.
 *
 * @param name              Label used in the error output.
 * @param test              Values under test, N elements.
 * @param ref               Reference values, N elements.
 * @param N                 Number of elements to compare.
 * @param mismatch_rate_tol Tolerated fraction of mismatching elements.
 */
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
                    size_t N, float mismatch_rate_tol) {
  const size_t max_mismatches = std::ceil(N * mismatch_rate_tol);
  std::vector<size_t> mismatch_indices;
  // size_t index: N is size_t, an int counter would overflow for large N.
  for (size_t i = 0; i < N; ++i) {
    if (test[i] == ref[i]) {
      continue;
    }
    mismatch_indices.push_back(i);
    if (mismatch_indices.size() > max_mismatches) {
      std::cout << "Error in " << name << std::endl;
      for (const size_t index : mismatch_indices) {
        // Report the values at each recorded position, not at the position
        // that happened to trip the limit.
        std::cout << "Mismatch at (" << index << "):" << static_cast<int>(test[index]) << " vs "
                  << static_cast<int>(ref[index]) << std::endl;
      }
      GTEST_FAIL() << mismatch_indices.size() << " mismatche(s) which is more than mismatch tol.";
    }
  }
}

822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
// Type trait mapping a scaling-factor storage type to the type it is cast to
// before being streamed in mismatch reports (see compare_scaling_factors).
// Only the specializations below are defined; the primary template is left
// incomplete, so using it with any other type is a compile error.
template <typename T>
struct CastToType;

// uint8_t values are printed as integers rather than as characters.
template <>
struct CastToType<uint8_t> {
  using type = int;
};

// fp8e4m3 values are printed as floats.
template <>
struct CastToType<fp8e4m3> {
  using type = float;
};

template <typename T>
void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
                             const size_t row_blocks, const size_t col_blocks, const size_t stride,
                             size_t& mismatches_num, const size_t atol,
                             const double abs_tolerable_mismatches_limit,
                             const double rel_tolerable_mismatches_limit)
841
{
842
843
844
845
  using UpcastType = typename CastToType<T>::type;
  auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3);


846
847
848
849
850
851
  const size_t N = row_blocks * col_blocks;
  const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit,
                                                     std::floor(N * rel_tolerable_mismatches_limit));
  mismatches_num = 0;
  std::vector<int> mismatch_indices;

852
853
854
  for (int i = 0; i < row_blocks; ++i) {
    for (int j = 0; j < col_blocks; ++j) {
      const int idx = i * stride + j;
855
856
857
      float t, r;

      bool assertion = false;
858

859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
      if (std::is_same<T, uint8_t>::value) {
        t = static_cast<float>(test[idx]);
        r = static_cast<float>(ref[idx]);
        assertion = std::abs(t - r) > atol;
      } else {
        t = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&test[idx]));
        r = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&ref[idx]));
        const bool mismatch = (fabs(t - r) > atol_fp8e4m3)
                              && (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3);
        if (mismatch) {
          /* Check if it is just a failure of round to nearest choosing different
            side of the real value */
          const double mean = (t + r) / 2;
          const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
          const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
          const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
          const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
          assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
        }
      }
      if (assertion) {
880
881
882
883
884
885
886
        mismatches_num++;
        mismatch_indices.push_back(idx);
      }
      if (mismatches_num > tolerable_mismatches_limit) {
        std::cout << "Error in " << name << std::endl;
        for (const int index : mismatch_indices) {
          std::cout << "Mismatch at (" << index << "):"
887
888
                    << static_cast<UpcastType>(test[index]) << " vs "
                    << static_cast<UpcastType>(ref[index]) << std::endl;
889
890
891
892
893
        }
        GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of "
                     << tolerable_mismatches_limit << ".";
      }
    }
894
895
896
  }
}

897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
// Explicit instantiations of compare_scaling_factors for the two element
// types that have a CastToType specialization (uint8_t and fp8e4m3). The
// template is defined in this source file, so these make the two variants
// available to other translation units.
template
void compare_scaling_factors<uint8_t>(const std::string &name, const uint8_t *test, const uint8_t *ref,
                                      const size_t row_blocks, const size_t col_blocks, const size_t stride,
                                      size_t& mismatches_num, const size_t atol,
                                      const double abs_tolerable_mismatches_limit,
                                      const double rel_tolerable_mismatches_limit);

template
void compare_scaling_factors<fp8e4m3>(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref,
                                      const size_t row_blocks, const size_t col_blocks, const size_t stride,
                                      size_t& mismatches_num, const size_t atol,
                                      const double abs_tolerable_mismatches_limit,
                                      const double rel_tolerable_mismatches_limit);


Przemek Tredak's avatar
Przemek Tredak committed
913
914
915
916
917
918
919
920
921
922
/**
 * Returns the default comparison tolerances for a data type.
 *
 * @param type Data type to look up.
 * @return Pair of two tolerance values for the type; callers in this file
 *         bind them as {atol, rtol}.
 *
 * Types other than Float32, Float16, BFloat16 and the FP8 variants are
 * rejected through NVTE_CHECK.
 */
std::pair<double, double> getTolerances(const DType type) {
  switch(type) {
    case DType::kFloat32:
      return {1e-6, 5e-6};
    case DType::kFloat16:
      return {1e-5, 1e-3};
    case DType::kBFloat16:
      return {1e-5, 1e-2};
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
    case DType::kFloat8E8M0:
      return {1e-2, 1e-2};
    default:
      // The condition must be an explicit `false`: passing only the message
      // makes the string literal the condition, which is always true, so the
      // check could never fire.
      NVTE_CHECK(false, "Invalid type!");
  }
  // Fallback return for the default path, kept from the original so every
  // control path returns a value.
  return {0, 0};
}

931
932
template <typename T>
void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
933
934
935
936
937
938
939
940
941
942
943
944
945
  // Check how many RNG calls are required to generate one uniform random value
  int rng_calls_per_val = 0;
  {
    std::mt19937 gen1 = *gen, gen2 = *gen;
    std::uniform_real_distribution<> dis(-2.0, 1.0);
    const float _ = dis(gen1);
    while (gen2 != gen1) {
      auto _ = gen2();
      ++rng_calls_per_val;
    }
  }

  // Generate uniform random values in parallel
946
947
948
  #pragma omp parallel proc_bind(spread)
  {
    std::mt19937 gen_local = *gen;
949
950
951
952
953
    const int thread_ID = omp_get_thread_num();
    const int threads_num = omp_get_max_threads();
    const int chunk_size = (size + threads_num - 1) / threads_num;
    const int idx_min = chunk_size * thread_ID;
    const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast<int>(size));
954
    gen_local.discard(idx_min * rng_calls_per_val);
955
    std::uniform_real_distribution<> dis(-2.0, 1.0);
956
957

    for (int i = idx_min; i < idx_max; ++i) {
yuguo's avatar
yuguo committed
958
#ifndef __HIP_PLATFORM_AMD__
959
      data[i] = static_cast<T>(dis(gen_local));
yuguo's avatar
yuguo committed
960
961
962
#else
      data[i] = static_cast<T>(static_cast<float>(dis(gen_local)));
#endif
963
964
    }
  }
965
  gen->discard(size * rng_calls_per_val);
966
967
}

968
/**
 * Fills the tensor's CPU buffer with uniformly distributed random values,
 * sets a random inverse scale and uploads the tensor with from_cpu().
 *
 * The rowwise buffer is filled when the tensor reports rowwise data;
 * otherwise the columnwise buffer is filled.
 *
 * @param t Tensor to fill; its own generator (t->gen()) supplies the randomness.
 */
void fillUniform(Tensor *t) {
  const bool fill_rowwise = t->rowwise();
  const size_t size = fill_rowwise ? product(t->rowwise_shape())
                                   : product(t->columnwise_shape());

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T,
    {
      T *data = fill_rowwise ? t->rowwise_cpu_dptr<T>() : t->columnwise_cpu_dptr<T>();
      generate_data_uniformly(data, size, &(t->gen()));
    }
  );

  // The inverse scale is drawn after the data, from the same generator.
  std::uniform_real_distribution<> scale_inv_dis(-2.0, 1.0);
  t->set_scale_inv(scale_inv_dis(t->gen()));
  t->from_cpu();
}

/**
 * Fills the tensor's rowwise CPU buffer according to a compile-time fill
 * case, sets the inverse scale to 1 and uploads the tensor with from_cpu().
 *
 *  - zeros:   every element is set to 0.
 *  - uniform: magnitudes are drawn from [-2, 1) and then negated with
 *             probability one half.
 *  - other:   magnitudes are drawn between
 *             Quantized_Limits<InputEncoding>::ranges[Case] and
 *             ranges[Case + 1], and then negated with probability one half.
 *
 * @tparam InputEncoding Encoding whose Quantized_Limits ranges are used.
 * @tparam Case          Fill case selecting the behavior above.
 * @param t              Tensor to fill; t->gen() supplies the randomness.
 */
template<typename InputEncoding, InputsFillCase Case>
void fillCase_special(Tensor *t) {
  const size_t size = product(t->rowwise_shape());

  if constexpr (Case == InputsFillCase::zeros) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
      InputType *data = t->rowwise_cpu_dptr<InputType>();
      for (size_t pos = 0; pos != size; ++pos) {
        data[pos] = static_cast<InputType>(0);
      }
    });
  } else {
    double lower_bound = -2.0;
    double upper_bound = 1.0;
    if constexpr (Case != InputsFillCase::uniform) {
      lower_bound = Quantized_Limits<InputEncoding>::ranges[Case];
      upper_bound = Quantized_Limits<InputEncoding>::ranges[Case + 1];
    }
    std::uniform_real_distribution<> magnitude_dis(lower_bound, upper_bound);
    std::uniform_real_distribution<> sign_dis(-1.0, 1.0);
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
      InputType *data = t->rowwise_cpu_dptr<InputType>();
      for (size_t pos = 0; pos != size; ++pos) {
        // Draw order matters for reproducibility: sign first, then magnitude.
        const double sign_draw = sign_dis(t->gen());
        const double magnitude = magnitude_dis(t->gen());
        const double val = (sign_draw < 0.0) ? -magnitude : magnitude;
        data[pos] = static_cast<InputType>(val);
      }
    });
  }

  t->set_scale_inv(1.0);
  t->from_cpu();
}

1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
/**
 * Runtime dispatcher that forwards to the fillCase_special instantiation
 * matching `fill_case`. Values not listed below leave the tensor untouched.
 *
 * @tparam InputEncoding Encoding forwarded to fillCase_special.
 * @param t              Tensor to fill.
 * @param fill_case      Requested fill case.
 */
template <typename InputEncoding>
void fillCase(Tensor *t, const InputsFillCase fill_case) {
  if (fill_case == InputsFillCase::uniform) {
    fillCase_special<InputEncoding, InputsFillCase::uniform>(t);
  } else if (fill_case == InputsFillCase::zeros) {
    fillCase_special<InputEncoding, InputsFillCase::zeros>(t);
  } else if (fill_case == InputsFillCase::zero_to_minNorm) {
    fillCase_special<InputEncoding, InputsFillCase::zero_to_minNorm>(t);
  } else if (fill_case == InputsFillCase::minNorm_to_maxNorm) {
    fillCase_special<InputEncoding, InputsFillCase::minNorm_to_maxNorm>(t);
  } else if (fill_case == InputsFillCase::maxNorm_to_inf) {
    fillCase_special<InputEncoding, InputsFillCase::maxNorm_to_inf>(t);
  }
}

1043
1044
1045
1046
1047
1048
1049
// Explicit instantiations of fillCase, one per supported input encoding.
// Integer encodings.
template void fillCase<byte>(Tensor *, InputsFillCase);
template void fillCase<int16>(Tensor *, InputsFillCase);
template void fillCase<int32>(Tensor *, InputsFillCase);
template void fillCase<int64>(Tensor *, InputsFillCase);
// 32- and 16-bit floating-point encodings.
template void fillCase<fp32>(Tensor *, InputsFillCase);
template void fillCase<fp16>(Tensor *, InputsFillCase);
template void fillCase<bf16>(Tensor *, InputsFillCase);
// 8-bit floating-point encodings.
template void fillCase<fp8e4m3>(Tensor *, InputsFillCase);
template void fillCase<fp8e5m2>(Tensor *, InputsFillCase);
// 4-bit floating-point encoding, only when FP4_TYPE_SUPPORTED is non-zero.
#if FP4_TYPE_SUPPORTED
template void fillCase<fp4e2m1>(Tensor *, InputsFillCase);
#endif
1055

1056
1057
/**
 * Sets the tensor's scale to a value drawn uniformly from [-2, 1) using the
 * tensor's own generator.
 *
 * @param t Tensor whose scale is set.
 */
void setRandomScale(Tensor *t) {
  std::uniform_real_distribution<> scale_dis(-2.0, 1.0);
  // Narrowed to float before being handed to set_scale, as before.
  const float random_scale = static_cast<float>(scale_dis(t->gen()));
  t->set_scale(random_scale);
}

1062
1063
1064
1065
1066
1067
/**
 * Sets the tensor's inverse scale to a value drawn uniformly from [-2, 1)
 * using the tensor's own generator.
 *
 * @param t Tensor whose inverse scale is set.
 */
void setRandomScaleInv(Tensor *t) {
  std::uniform_real_distribution<> scale_inv_dis(-2.0, 1.0);
  // Narrowed to float before being handed to set_scale_inv, as before.
  const float random_scale_inv = static_cast<float>(scale_inv_dis(t->gen()));
  t->set_scale_inv(random_scale_inv);
}

1068
/**
 * Tells whether `type` is one of the 8-bit floating-point types
 * (kFloat8E4M3, kFloat8E5M2 or kFloat8E8M0).
 */
bool isFp8Type(DType type) {
  switch (type) {
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
    case DType::kFloat8E8M0:
      return true;
    default:
      return false;
  }
}

1072
/**
 * Tells whether `type` is the 4-bit floating-point type kFloat4E2M1.
 * Always returns false when FP4_TYPE_SUPPORTED evaluates to zero.
 */
bool isFp4Type(DType type) {
#if !FP4_TYPE_SUPPORTED
  return false;
#else
  return DType::kFloat4E2M1 == type;
#endif
}

1080
1081
1082
1083
int32_t getDeviceComputeCapability() {
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, 0);
  return 10 * deviceProp.major + deviceProp.minor;
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
}

/**
 * Returns the size of the flattened leading dimensions of `shape`, computed
 * as product(shape, 0, shape.size() - 1).
 * Shapes with zero or one dimension yield 1.
 */
size_t first_dimension(const std::vector<size_t> &shape) {
  const size_t ndim = shape.size();
  if (ndim <= 1) {
    return 1;
  }
  return product(shape, 0, ndim - 1);
}

/**
 * Returns the innermost (last) dimension of `shape`, or 1 for an empty shape.
 */
size_t last_dimension(const std::vector<size_t> &shape) {
  return shape.empty() ? 1 : shape.back();
}

std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
                                            const size_t cols,
                                            const size_t block_size_rows,
                                            const size_t block_size_cols) {
1101
1102
    const bool is_rowwise = (block_size_rows == 1)
                            && ((block_size_cols == 32) || (block_size_cols == 16));
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116

    const size_t alignment_Y = is_rowwise
                               ? scale_tensor_alignment_Y_rowwise
                               : scale_tensor_alignment_Y_colwise;
    const size_t alignment_X = is_rowwise
                               ? scale_tensor_alignment_X_rowwise
                               : scale_tensor_alignment_X_colwise;

    const size_t unpadded_blocks_Y = divide_round_up(rows, block_size_rows);
    const size_t unpadded_blocks_X = divide_round_up(cols, block_size_cols);

    const size_t blocks_Y = round_up_to_nearest_multiple(unpadded_blocks_Y, alignment_Y);
    const size_t blocks_X = round_up_to_nearest_multiple(unpadded_blocks_X, alignment_X);
    return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X};
1117
1118
}

Przemek Tredak's avatar
Przemek Tredak committed
1119
}  // namespace test