/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/


#include "test_common.h"
Tim Moon's avatar
Tim Moon committed
9

Przemek Tredak's avatar
Przemek Tredak committed
10
11
12
#include <algorithm>
#include <memory>
#include <random>
13
#include <iostream>
14
15
16
#include <cassert>
#include <cmath>
#include <string>
Przemek Tredak's avatar
Przemek Tredak committed
17

Tim Moon's avatar
Tim Moon committed
18
#include <gtest/gtest.h>
19
#include <omp.h>
Tim Moon's avatar
Tim Moon committed
20
21
22
23

#include <transformer_engine/transformer_engine.h>
#include "util/logging.h"

Przemek Tredak's avatar
Przemek Tredak committed
24
25
namespace test {

26
27
28
29
30
31
size_t create_seed_from_tensor_name(const std::string& tensor_name) {
  auto full_name = std::string(testing::UnitTest::GetInstance()->current_test_info()->name()) +
                   "/" + tensor_name;
  return std::hash<std::string>{}(full_name);
}

Przemek Tredak's avatar
Przemek Tredak committed
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
// Every floating-point DType the tests iterate over (FP32, FP16, BF16 and
// the two FP8 variants).
std::vector<DType> all_fp_types = {DType::kFloat32,
                                   DType::kFloat16,
                                   DType::kBFloat16,
                                   DType::kFloat8E5M2,
                                   DType::kFloat8E4M3};

bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
  // Two shapes match iff they have the same rank and identical extents.
  if (s1.ndim != s2.ndim) return false;
  return std::equal(s1.data, s1.data + s1.ndim, s2.data);
}

48
// Returns the storage width of one element of `type`.
// NOTE(review): TypeInfo<T>::size is assumed to be expressed in bits --
// callers such as bytes() divide the result by 8 -- confirm against the
// TypeInfo definition.  Also relies on TRANSFORMER_ENGINE_TYPE_SWITCH_ALL
// covering every DType so that control cannot fall off the end.
size_t typeToNumBits(DType type) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
  {
      return TypeInfo<T>::size;
  });
}

const std::string &typeName(DType type) {
  // Printable name for each DType; lazily built once, lives for the program.
  static const std::unordered_map<DType, std::string> names = {
      {DType::kByte, "byte"},
      {DType::kInt32, "int32"},
      {DType::kInt64, "int64"},
      {DType::kFloat32, "float32"},
      {DType::kFloat16, "float16"},
      {DType::kBFloat16, "bfloat16"},
      {DType::kFloat8E4M3, "float8e4m3"},
      {DType::kFloat8E5M2, "float8e5m2"},
      {DType::kFloat8E8M0, "float8e8m0"}
#if FP4_TYPE_SUPPORTED
      ,
      {DType::kFloat4E2M1, "float4e2m1"}
#endif
  };
  // at() deliberately throws for a DType without a registered name.
  return names.at(type);
}

74
75
76
77
78
79
80
81
82
83
84
const std::string& caseName(InputsFillCase type) {
  // Human-readable label for each input-fill strategy (used in test names).
  static const std::unordered_map<InputsFillCase, std::string> labels = {
      {InputsFillCase::uniform, "uniform"},
      {InputsFillCase::zeros, "zeros"},
      {InputsFillCase::zero_to_minNorm, "zero_to_minNorm"},
      {InputsFillCase::minNorm_to_maxNorm, "minNorm_to_maxNorm"},
      {InputsFillCase::maxNorm_to_inf, "maxNorm_to_inf"},
  };
  return labels.at(type);
}

size_t product(const NVTEShape &shape, size_t begin, size_t end) {
Przemek Tredak's avatar
Przemek Tredak committed
85
    size_t ret = 1;
86
87
    NVTE_CHECK(end <= shape.ndim);
    for (size_t i = begin; i < end; ++i) {
Przemek Tredak's avatar
Przemek Tredak committed
88
89
90
91
      ret *= shape.data[i];
    }
    return ret;
}
92

93
94
95
size_t product(const NVTEShape &shape) {
  // Total element count: product over every dimension.
  return product(shape, 0, shape.ndim);
}
96

97
98
99
100
101
102
103
104
// Computes shape[begin] * ... * shape[end-1]; an empty range yields 1.
// Fixed: the vector is now taken by const reference -- the previous
// signature took it by value, copying the whole vector on every call.
size_t product(const std::vector<size_t>& shape, size_t begin, size_t end) {
    NVTE_CHECK(end <= shape.size());
    size_t ret = 1;
    for (size_t i = begin; i < end; ++i) {
      ret *= shape[i];
    }
    return ret;
}
Przemek Tredak's avatar
Przemek Tredak committed
105

106
107
108
109
110
111
112
113
114
115
116
size_t product(const std::vector<size_t>& shape) {
  // Total element count: multiply every extent together (empty shape -> 1).
  return product(shape, 0, shape.size());
}

size_t DIVUP(const size_t &x, const size_t &y){
  // Ceiling division: smallest q such that q * y >= x.  y must be non-zero.
  const size_t rounded_up = x + y - 1;
  return rounded_up / y;
}

// Describes a scale-inverse tensor: its (padded) shape, element DType, and
// element width in bits.
struct scale_inv_meta {
  std::vector<size_t> shape;
  DType type;
  size_t type_size_bits;
  // Total storage in bytes: (element count * bits per element) / 8.
  size_t bytes() const noexcept {
    return (product(shape) * type_size_bits) / 8;
  }
};

123
124
125
126
size_t bytes(const NVTEShape& shape, const DType type) {
  // Storage in bytes for a dense tensor of `shape` with elements of `type`.
  const size_t total_bits = product(shape) * typeToNumBits(type);
  return total_bits / 8;
}

127
128
// Wraps a std::vector shape into the NVTEShape C struct.
NVTEShape convertShape(const std::vector<size_t>& s) {
  return nvte_make_shape(s.data(), s.size());
}

// Computes the shape/type metadata of the rowwise and columnwise
// scale-inverse tensors for a data tensor of `shape` under `scaling_mode`.
// Returns {rowwise_meta, columnwise_meta}; aborts via NVTE_ERROR on an
// unsupported mode.
std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
                                                     const NVTEScalingMode scaling_mode) {
  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
    // Delayed tensor scaling: a single scalar FP32 scale-inverse, identical
    // for both directions.
    scale_inv_meta ret;
    ret.shape = {1};
    ret.type = DType::kFloat32;
    ret.type_size_bits = typeToNumBits(DType::kFloat32);
    return {ret, ret};
  }
  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
    // MXFP8: one E8M0 scale per 32-element block along the scaled dimension,
    // padded up to a 128x4 (rowwise) / 4x128 (columnwise) alignment grid.
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    scale_inv_meta ret_rowwise, ret_colwise;

    auto block_alignment = std::vector<size_t>{128ul, 4ul};
    {
      // Rowwise: blocks of 32 along the last dimension.
      auto alignment = block_alignment[0];
      auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(1)), alignment) * alignment;
      alignment = block_alignment[1];
      auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(32)), alignment) * alignment;
      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
    }
    {
      // Columnwise: blocks of 32 along the first dimension.
      auto alignment = block_alignment[1];
      auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(32)), alignment) * alignment;
      alignment = block_alignment[0];
      auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(1)), alignment) * alignment;
      ret_colwise.shape = {scale_dim_0, scale_dim_1};
    }
    ret_rowwise.type = DType::kFloat8E8M0;
    ret_colwise.type = DType::kFloat8E8M0;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);

    return {ret_rowwise, ret_colwise};
  }
  if (scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // 2D block scaling: one FP32 scale per block_len x block_len tile; the
    // minor scale dimension is padded to a multiple of 4.
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    scale_inv_meta ret_rowwise, ret_colwise;

    {
      auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len()));
      auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len())), 4) * 4;
      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
    }
    {
      // Columnwise swaps the roles of the two dimensions.
      auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len()));
      auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len())), 4) * 4;
      ret_colwise.shape = {scale_dim_0, scale_dim_1};
    }
    ret_rowwise.type = DType::kFloat32;
    ret_colwise.type = DType::kFloat32;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);

    return {ret_rowwise, ret_colwise};
  }
  if (scaling_mode == NVTE_BLOCK_SCALING_1D) {
    // 1D block scaling: one FP32 scale per block_len-sized strip.
    // NOTE(review): the scale matrices here are laid out transposed relative
    // to the data (rowwise dim 0 derives from last_dim) -- confirm against
    // the blockwise-scaling kernel's expected layout.
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);
    scale_inv_meta ret_rowwise, ret_colwise;

    {
      auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len()));
      auto scale_dim_1 = DIVUP(first_dim, 4) * 4;
      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
    }
    {
      auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len()));
      auto scale_dim_1 = DIVUP(last_dim, 4) * 4;
      ret_colwise.shape = {scale_dim_0, scale_dim_1};
    }
    ret_rowwise.type = DType::kFloat32;
    ret_colwise.type = DType::kFloat32;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
    return {ret_rowwise, ret_colwise};
  }
  NVTE_ERROR("Invalid scaling mode!");
}

// Allocates the device buffers (and mirrored zero-filled host buffers) for a
// test tensor.  `rowwise`/`columnwise` select which layouts are materialized;
// FP8 dtypes additionally get amax/scale/scale-inverse buffers according to
// `scaling_mode`.  The RNG is seeded deterministically from the test name
// plus the tensor name.
// NOTE(review): cudaMalloc/cudaMemset return codes are not checked
// (test-only code) -- allocation failures surface later as bad pointers.
Tensor::Tensor(const std::string& name,
               const NVTEShape &shape, const DType type,
               const bool rowwise, const bool columnwise,
               const NVTEScalingMode &scaling_mode) {
  name_ = name;
  const size_t seed = create_seed_from_tensor_name(name);
  gen_.seed(seed);
  rowwise_ = rowwise;
  columnwise_ = columnwise;
  size_t total_size = bytes(shape, type);
  void *dptr_rowwise = nullptr;
  void *dptr_columnwise = nullptr;
  cpu_data_rowwise_ = nullptr;
  cpu_data_columnwise_ = nullptr;
  amax_cpu_data_ = nullptr;
  scale_cpu_data_ = nullptr;
  rowwise_scale_inv_cpu_data_ = nullptr;
  columnwise_scale_inv_cpu_data_ = nullptr;
  float *amax = nullptr, *scale = nullptr;
  float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr;
  if (columnwise) {
    NVTE_CHECK(shape.ndim >= 2);
  }
  // Collapse all leading dims so the scale computation sees a 2D
  // (rows, cols) view of the tensor.
  std::vector<size_t> normalized_shape_v = {product(shape, 0, shape.ndim - 1),
                                            shape.data[shape.ndim - 1]};
  NVTEShape normalized_shape = convertShape(normalized_shape_v);
  NVTEShape columnwise_shape = {};

  std::vector<size_t> columnwise_shape_vec;
  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // Transpose when tensor scaling
    columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
    for (size_t i = 0; i < shape.ndim - 1; ++i) {
      columnwise_shape_vec.emplace_back(shape.data[i]);
    }
  } else {
    // Same shape for MX
    for (size_t i = 0; i < shape.ndim; ++i) {
      columnwise_shape_vec.emplace_back(shape.data[i]);
    }
  }

  if (columnwise) {
    columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(), columnwise_shape_vec.size());
  }

  tensor_ = TensorWrapper(scaling_mode);

  // Allocate and zero both the device buffers and their host mirrors.
  if (total_size != 0) {
    if (rowwise) {
      cudaMalloc((void**)&dptr_rowwise, total_size);  // NOLINT(*)
      cudaMemset(dptr_rowwise, 0, total_size);
      cpu_data_rowwise_ = std::make_unique<unsigned char[]>(total_size);
      std::fill_n(cpu_data_rowwise_.get(), total_size, 0);
    }
    if (columnwise) {
      cudaMalloc((void**)&dptr_columnwise, total_size);  // NOLINT(*)
      cudaMemset(dptr_columnwise, 0, total_size);
      cpu_data_columnwise_ = std::make_unique<unsigned char[]>(total_size);
      std::fill_n(cpu_data_columnwise_.get(), total_size, 0);
    }
  }
  tensor_.set_rowwise_data(dptr_rowwise, type, shape);
  tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape);

  if (isFp8Type(type)) {
    if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
      // Delayed tensor scaling: scalar amax, scale, and scale-inverse.
      cudaMalloc((void**)&amax, sizeof(float));  // NOLINT(*)
      cudaMemset(amax, 0, sizeof(float));
      cudaMalloc((void**)&scale, sizeof(float));  // NOLINT(*)
      cudaMemset(scale, 0, sizeof(float));
      amax_cpu_data_ = std::make_shared<float>(0);
      scale_cpu_data_ = std::make_shared<float>(0);
      tensor_.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
      tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      cudaMalloc((void**)&rowwise_scale_inv, sizeof(float));  // NOLINT(*)
      if (rowwise) {
        tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
                                      std::vector<size_t>{1});
        rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
        std::fill_n(rowwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
      }
      if (columnwise) {
        // NOTE(review): the columnwise view is registered with the SAME
        // device allocation as the rowwise scale-inverse -- presumably
        // intentional since delayed scaling has a single scalar; confirm.
        tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
                                         std::vector<size_t>{1});
        columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
        std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
      }
    } else {
      // Block/MX scaling: per-block scale-inverse tensors sized by get_scales.
      auto [rowwise_scale_meta, colwise_scale_meta] =
          get_scales(normalized_shape, tensor_.scaling_mode());
      auto rowwise_scale_size = rowwise_scale_meta.bytes();
      auto columnwise_scale_size = colwise_scale_meta.bytes();
      auto scale_shape = rowwise_scale_meta.shape;
      auto columnwise_scale_shape = colwise_scale_meta.shape;
      if (rowwise) {
        cudaMalloc((void **)&rowwise_scale_inv, rowwise_scale_size);  // NOLINT(*)
        cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size);
        rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(rowwise_scale_size);
        std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0);
        auto scale_dtype = rowwise_scale_meta.type;
        tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape);
      }
      if (columnwise) {
        cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size);  // NOLINT(*)
        cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size);
        columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(columnwise_scale_size);
        std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0);
        auto scale_dtype = colwise_scale_meta.type;
        tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape);
      }
    }
  }
}

// Copies all device-side buffers (data, and for FP8 also amax/scale/
// scale-inverse) into their host mirrors with blocking cudaMemcpy calls.
void Tensor::to_cpu() const {
  const NVTEShape s = tensor_.shape();
  const size_t size = bytes(s, tensor_.dtype());
  if (rowwise_) {
    cudaMemcpy(cpu_data_rowwise_.get(),
               tensor_.get_rowwise_data().data_ptr,
               size,
               cudaMemcpyDeviceToHost);
  }
  if (columnwise_) {
    cudaMemcpy(cpu_data_columnwise_.get(),
               tensor_.get_columnwise_data().data_ptr,
               size,
               cudaMemcpyDeviceToHost);
  }
  if (isFp8Type(dtype())) {
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
      // amax may be absent (e.g. shared meta); scale is copied
      // unconditionally for delayed scaling.
      if (tensor_.amax() != nullptr){
        cudaMemcpy(amax_cpu_data_.get(),
                  tensor_.amax(),
                  sizeof(float),
                  cudaMemcpyDeviceToHost);
      }
      cudaMemcpy(scale_cpu_data_.get(),
                 tensor_.scale(),
                 sizeof(float),
                 cudaMemcpyDeviceToHost);
    }
    // Scale-inverse sizes depend on the scaling mode.
    auto [rowwise_scale_meta, colwise_scale_meta] =
        get_scales(s, tensor_.scaling_mode());
    if (rowwise_) {
      auto scale_size = rowwise_scale_meta.bytes();
      cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
                 tensor_.get_rowwise_scale_inv().data_ptr,
                 scale_size,
                 cudaMemcpyDeviceToHost);
    }
    if (columnwise_) {
      auto scale_size = colwise_scale_meta.bytes();
      cudaMemcpy(columnwise_scale_inv_cpu_data_.get(),
                 tensor_.get_columnwise_scale_inv().data_ptr,
                 scale_size,
                 cudaMemcpyDeviceToHost);
    }
  }
}

// Pushes the host mirrors (data, and for FP8 also amax/scale/scale-inverse)
// back to the device with blocking cudaMemcpy calls.  Inverse of to_cpu().
void Tensor::from_cpu() const {
  const NVTEShape s = tensor_.shape();
  const size_t size = bytes(s, tensor_.dtype());
  if (rowwise_) {
    cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size,
               cudaMemcpyHostToDevice);
  }
  if (columnwise_) {
    cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
               cudaMemcpyHostToDevice);
  }
  if (isFp8Type(dtype())) {
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
      // amax may be absent (e.g. shared meta); scale is written
      // unconditionally for delayed scaling.
      if (tensor_.amax() != nullptr){
        cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
      }
      cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
    }
    // Scale-inverse sizes depend on the scaling mode.
    auto [rowwise_scale_meta, colwise_scale_meta] =
        get_scales(s, tensor_.scaling_mode());
    if (rowwise_) {
      auto scale_size = rowwise_scale_meta.bytes();
      cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
                 rowwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
    }
    if (columnwise_) {
      auto scale_size = colwise_scale_meta.bytes();
      cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr,
                 columnwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
    }
  }
}

void Tensor::set_scale(float scale) {
  if (isFp8Type(dtype())) {
    NVTE_CHECK(scale_cpu_data_);
428
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
429
430
431
      *scale_cpu_data_ = scale;
      from_cpu();
    }
432
433
434
435
436
  }
}

// Initializes the scale-inverse buffer(s) and syncs them to the device.
// For a scalar scale-inverse (delayed scaling) the given value is used
// verbatim; for per-block scales each byte is filled with a random value in
// [0, 127] (E8M0 exponents / raw bytes) and `scale_inv` is ignored.
// NOTE(review): std::uniform_int_distribution<uint8_t> is not one of the
// standard-sanctioned IntTypes (char types are UB per the C++ standard,
// though major implementations accept it) -- consider distributing over
// int and casting.
void Tensor::set_scale_inv(float scale_inv) {
  if (isFp8Type(dtype())) {
    if (rowwise_) {
      NVTE_CHECK(rowwise_scale_inv_cpu_data_);
    }
    if (columnwise_) {
      NVTE_CHECK(columnwise_scale_inv_cpu_data_);
    }

    auto [rowwise_scale_meta, colwise_scale_meta] =
        get_scales(tensor_.shape(), tensor_.scaling_mode());
    if (rowwise_) {
      auto num_scales = product(rowwise_scale_meta.shape);
      if (num_scales == 1) {
        rowwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
      } else {
        std::uniform_int_distribution<uint8_t> dis(0, 127);
        auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr<uint8_t>();
        for (size_t i = 0; i < num_scales; i++) {
          scale_inv_ptr[i] = dis(gen_);
        }
      }
    }
    if (columnwise_) {
      auto num_scales = product(colwise_scale_meta.shape);
      if (num_scales == 1) {
        columnwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
      } else {
        std::uniform_int_distribution<uint8_t> dis(0, 127);
        auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr<uint8_t>();
        for (size_t i = 0; i < num_scales; i++) {
          scale_inv_ptr[i] = dis(gen_);
        }
      }
    }
    from_cpu();
  }
}

// Rebuilds this tensor's wrapper so it keeps its own data pointers but
// borrows `other`'s FP8 metadata (amax, scale, and both scale-inverses).
// No-op unless both tensors hold FP8 data.  Ends with to_cpu() so the host
// mirrors reflect the newly shared metadata.
void Tensor::shareFP8Meta(const Tensor &other) {
  if (isFp8Type(dtype()) && isFp8Type(other.tensor_.dtype() == other.dtype() ? other.dtype() : other.dtype())) {
    // (see original condition below)
  }
}

using std::to_string;

using std::to_string;

// Formats a vector as "[a, b, c]".
// Fixed: an empty vector is now rendered as "[]" -- the previous version
// unconditionally called pop_back() twice, which popped the '[' and then
// invoked undefined behavior (pop_back() on an empty std::string).
template <typename T>
std::string to_string(const std::vector<T> &v) {
  std::string s = "[";
  for (size_t i = 0; i < v.size(); ++i) {
    if (i != 0) s += ", ";
    s += to_string(v[i]);
  }
  return s + "]";
}

std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
  std::vector<size_t> ret;
  size_t current_i = i;
519
  for (size_t current = shape.ndim - 1; current > 0; --current) {
Przemek Tredak's avatar
Przemek Tredak committed
520
521
522
523
524
525
526
527
    ret.push_back(current_i % shape.data[current]);
    current_i /= shape.data[current];
  }
  ret.push_back(current_i);
  std::reverse(ret.begin(), ret.end());
  return ret;
}

528
529
// Single-threaded element-wise comparison of `test` against the reference
// buffer `ref`.  An element counts as a mismatch when it exceeds both the
// absolute and relative tolerance AND is not explained by round-to-nearest
// landing on the other side of the midpoint; the test fails once more than
// `tolerable_mismatches_limit` such elements are seen.
// On HIP, values are routed through float before widening to double because
// half/bf8 types there lack a direct double conversion.
void compareResults_sequential(const std::string &name, const Tensor &test,
                               const void *ref, const bool rowwise,
                               double atol, double rtol, bool if_on_gpus,
                               const size_t tolerable_mismatches_limit) {
  if (if_on_gpus) test.to_cpu();
  const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
  const size_t N = product(shape);
  size_t mismatches_num = 0;
  int first_mismatch_idx = -1;
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
    const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
    const T *ref_data = reinterpret_cast<const T*>(ref);
    for (size_t i = 0; i < N; ++i) {
#ifndef __HIP_PLATFORM_AMD__
      double t = static_cast<double>(test_data[i]);
      double r = static_cast<double>(ref_data[i]);
#else
      double t = static_cast<double>(static_cast<float>(test_data[i]));
      double r = static_cast<double>(static_cast<float>(ref_data[i]));
#endif
      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
      /* For Float32 the floating point comparison is enough to error out */
      bool assertion = mismatch && test.dtype() == DType::kFloat32;
      if (mismatch && !assertion) {
        /* Check if it is just a failure of round to nearest choosing different
           side of the real value */
        const double mean = (t + r) / 2;
        const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
        const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
#ifndef __HIP_PLATFORM_AMD__
        const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
        const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
#else
        const double cast_mean_p = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_p))));
        const double cast_mean_m = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_m))));
#endif

#ifdef __HIP_PLATFORM_AMD__
        assertion = !(cast_mean_m == std::min<double>(t,r) && cast_mean_p == std::max<double>(t,r));
#else
        assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
#endif
      }
      std::string direction = rowwise ? "rowwise" : "columnwise";
      if (assertion) {
        mismatches_num++;
        if (first_mismatch_idx == -1) {
          first_mismatch_idx = i;
        }
      }
      if (mismatches_num > tolerable_mismatches_limit) {
        const double first_mismatch_t = static_cast<double>(test_data[first_mismatch_idx]);
        const double first_mismatch_r = static_cast<double>(ref_data[first_mismatch_idx]);

        GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of "
                    << tolerable_mismatches_limit << "." << std::endl
                    << "Error in tensor " << name << " in "
                    << direction << " direction." << std::endl
                     << "First mismatch at place " << to_string(unravel(first_mismatch_idx, shape))
                     << " (" << std::to_string(first_mismatch_idx) << "): "
                     << first_mismatch_t << " vs " << first_mismatch_r;
      }
    }
  );
}

// OpenMP-parallel scan for mismatching elements.  Returns the smallest
// mismatching index, or N if every element is within tolerance; the total
// mismatch count is accumulated into `mismatches`.  The tolerance /
// round-to-nearest logic mirrors compareResults_sequential.
// NOTE(review): first_mismatch_idx is an `int` (OpenMP min-reduction);
// initializing it with N and comparing against size_t `i` narrows/mixes
// signedness and would misbehave for N > INT_MAX -- confirm tensors in
// these tests stay below that.
template <typename T>
static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data,
                                  const size_t N, const double atol, const double rtol,
                                  size_t& mismatches) {
  int first_mismatch_idx = N;

  #pragma omp parallel reduction(min: first_mismatch_idx) reduction(+: mismatches) proc_bind(spread)
  {
    size_t thread_mismatches = 0;
    #pragma omp for schedule(static)
    for (size_t i = 0; i < N; ++i) {
#ifndef __HIP_PLATFORM_AMD__
    double t = static_cast<double>(test_data[i]);
    double r = static_cast<double>(ref_data[i]);
#else
    double t = static_cast<double>(static_cast<float>(test_data[i]));
    double r = static_cast<double>(static_cast<float>(ref_data[i]));
#endif

      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
      /* For Float32 the floating point comparison is enough to error out */
      bool assertion = mismatch && (data_type == DType::kFloat32);
      if (mismatch && !assertion) {
        /* Check if it is just a failure of round to nearest choosing different
            side of the real value */
        const double mean = (t + r) / 2;
        const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
        const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
#ifndef __HIP_PLATFORM_AMD__
      const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
      const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
#else
      const double cast_mean_p = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_p))));
      const double cast_mean_m = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_m))));
#endif
#ifdef __HIP_PLATFORM_AMD__
      assertion = !(cast_mean_m == std::min<double>(t,r) && cast_mean_p == std::max<double>(t,r));
#else
      assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
#endif
      }
      if (assertion) {
        if (i < first_mismatch_idx) {
          first_mismatch_idx = i;
        }
        thread_mismatches++;
      }
    }
    mismatches += thread_mismatches;
  }
  return first_mismatch_idx;
}

// OpenMP-parallel comparison of `test` against `ref`; fails the gtest case
// when the total mismatch count exceeds `tolerable_mismatches_limit`,
// reporting the first mismatching element.
void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref,
                             const bool rowwise, double atol, double rtol, bool if_on_gpus,
                             const size_t tolerable_mismatches_limit) {
  if (if_on_gpus) test.to_cpu();
  const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
  const size_t N = product(shape);
  size_t mismatches = 0;
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
    const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
    const T *ref_data = reinterpret_cast<const T*>(ref);

    // i == N means no mismatch at all was found.
    const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol, mismatches);
    if ((i != N) && (mismatches > tolerable_mismatches_limit)) {
#ifndef __HIP_PLATFORM_AMD__
      const double t = static_cast<double>(test_data[i]);
      const double r = static_cast<double>(ref_data[i]);
#else
      const double t = static_cast<double>(static_cast<float>(test_data[i]));
      const double r = static_cast<double>(static_cast<float>(ref_data[i]));
#endif
      std::string direction = rowwise ? "rowwise" : "columnwise";

      GTEST_FAIL() << mismatches << " mismatche(s) which is more than tolerable mismatch limit of "
                   << tolerable_mismatches_limit << "." << std::endl
                   << "Error in tensor " << name << " in "
                   << direction << " direction." << std::endl
                   << "Mismatch at place " << to_string(unravel(i, shape))
                   << " (" << std::to_string(i) << "): " << t << " vs " << r;
    }
  );
}

679
void compareResults(const std::string &name, const Tensor &test, const void *ref,
680
681
                    const bool rowwise, double atol, double rtol, bool if_on_gpus,
                    const size_t tolerable_mismatches_limit) {
682
683
  constexpr bool sequential = false;
  if constexpr (sequential) {
684
    compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit);
685
  } else {
686
    compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit);
687
688
689
  }
}

690
691
void compareResults(const std::string &name, const float test, const float ref,
                    double atol, double rtol) {
#ifndef __HIP_PLATFORM_AMD__
  const double t = static_cast<double>(test);
  const double r = static_cast<double>(ref);
#else
  const double t = static_cast<double>(static_cast<float>(test));
  const double r = static_cast<double>(static_cast<float>(ref));
#endif
  // Fail only when both the absolute and the relative tolerance are exceeded
  // (relative check is skipped when the reference is exactly zero).
  const bool exceeds_atol = fabs(t - r) > atol;
  const bool exceeds_rtol = (r == 0) || fabs((t - r) / r) > rtol;
  const bool mismatch = exceeds_atol && exceeds_rtol;
  ASSERT_FALSE(mismatch) << "Error in " << name << std::endl
                         << "Mismatch: " << t << " vs " << r;
}

705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727

// Byte-wise comparison of two buffers of length N, tolerating up to
// ceil(N * mismatch_rate_tol) differing bytes. On exceeding the limit,
// prints every recorded mismatch and fails the current GoogleTest test.
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
                    size_t N, float mismatch_rate_tol) {
  const size_t max_mismatches = std::ceil(N * mismatch_rate_tol);
  size_t n_mismatches = 0;
  std::vector<size_t> mismatch_indices;
  // size_t loop index: avoids the signed/unsigned comparison (and potential
  // overflow for very large N) of the previous `int i` counter.
  for (size_t i = 0; i < N; i++) {
    const bool mismatch = test[i] != ref[i];
    if (mismatch) {
      n_mismatches++;
      mismatch_indices.push_back(i);
    }
    if (n_mismatches > max_mismatches) {
      std::cout << "Error in " << name << std::endl;
      // Bug fix: report the values at each recorded mismatch position.
      // Previously this printed test[i]/ref[i] -- the *current* loop
      // position -- for every entry, so all reported pairs were identical
      // and did not correspond to the listed indices.
      for (const size_t index : mismatch_indices)
        std::cout << "Mismatch at (" << index << "):" << static_cast<int>(test[index]) << " vs "
        << static_cast<int>(ref[index]) << std::endl;
      GTEST_FAIL() << n_mismatches << " mismatche(s) which is more than mismatch tol.";
    }
  }
}

void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
728
729
730
731
                                    const size_t row_blocks, const size_t col_blocks, const size_t stride,
                                    size_t& mismatches_num, const size_t atol,
                                    const double abs_tolerable_mismatches_limit,
                                    const double rel_tolerable_mismatches_limit)
732
{
733
734
735
736
737
738
  const size_t N = row_blocks * col_blocks;
  const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit,
                                                     std::floor(N * rel_tolerable_mismatches_limit));
  mismatches_num = 0;
  std::vector<int> mismatch_indices;

739
740
741
  for (int i = 0; i < row_blocks; ++i) {
    for (int j = 0; j < col_blocks; ++j) {
      const int idx = i * stride + j;
742
743
744
      const int test_val = static_cast<int>(test[idx]);
      const int ref_val = static_cast<int>(ref[idx]);
      const int abs_delta = std::abs(test_val - ref_val);
745

746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
      if (abs_delta > atol) {
        mismatches_num++;
        mismatch_indices.push_back(idx);
      }
      if (mismatches_num > tolerable_mismatches_limit) {
        std::cout << "Error in " << name << std::endl;
        for (const int index : mismatch_indices) {
          std::cout << "Mismatch at (" << index << "):"
                    << static_cast<int>(test[index]) << " vs "
                    << static_cast<int>(ref[index]) << std::endl;
        }
        GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of "
                     << tolerable_mismatches_limit << ".";
      }
    }
761
762
763
  }
}

Przemek Tredak's avatar
Przemek Tredak committed
764
765
766
767
768
769
770
771
772
773
// Return the {atol, rtol} comparison tolerances appropriate for the
// precision of the given data type. Wider types get tighter tolerances.
std::pair<double, double> getTolerances(const DType type) {
  switch(type) {
    case DType::kFloat32:
      return {1e-6, 5e-6};
    case DType::kFloat16:
      return {1e-5, 1e-3};
    case DType::kBFloat16:
      return {1e-5, 1e-2};
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
    case DType::kFloat8E8M0:
      return {1e-2, 1e-2};
    default:
      // Bug fix: NVTE_CHECK("Invalid type!") passed a (non-null, hence
      // always-truthy) string literal as the *condition*, so unknown types
      // silently fell through to {0, 0}. Force the check to fail instead.
      NVTE_CHECK(false, "Invalid type!");
  }
  return {0, 0};
}

782
783
// Fill `data` with `size` values drawn uniformly from [-2, 1), splitting the
// work across OpenMP threads while producing the same sequence as a
// single-threaded pass over `*gen` would.
//
// Each thread copies the shared engine, skips ahead to its chunk's start with
// discard(), and generates its chunk independently. This relies on
// std::uniform_real_distribution consuming a fixed number of engine calls per
// generated value, which is measured empirically below rather than assumed.
template <typename T>
void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
  // Check how many RNG calls are required to generate one uniform random value
  int rng_calls_per_val = 0;
  {
    // Advance one copy of the engine by a single distribution draw, then step
    // a second copy one raw call at a time until their states match again;
    // the number of raw steps is the engine-call cost of one value.
    std::mt19937 gen1 = *gen, gen2 = *gen;
    std::uniform_real_distribution<> dis(-2.0, 1.0);
    const float _ = dis(gen1);
    while (gen2 != gen1) {
      auto _ = gen2();
      ++rng_calls_per_val;
    }
  }

  // Generate uniform random values in parallel
  #pragma omp parallel proc_bind(spread)
  {
    std::mt19937 gen_local = *gen;
    const int thread_ID = omp_get_thread_num();
    // NOTE(review): omp_get_max_threads() is used as the team size here, but
    // inside a parallel region omp_get_num_threads() gives the actual team
    // size. They coincide under the default launch configuration; the
    // chunking would be wrong if the region ever ran with fewer threads.
    const int threads_num = omp_get_max_threads();
    const int chunk_size = (size + threads_num - 1) / threads_num;
    const int idx_min = chunk_size * thread_ID;
    const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast<int>(size));
    // Skip the engine calls consumed by all chunks preceding this thread's.
    gen_local.discard(idx_min * rng_calls_per_val);
    std::uniform_real_distribution<> dis(-2.0, 1.0);

    for (int i = idx_min; i < idx_max; ++i) {
#ifndef __HIP_PLATFORM_AMD__
      data[i] = static_cast<T>(dis(gen_local));
#else
      // NOTE(review): HIP path narrows to float before converting to T --
      // presumably T lacks a direct conversion from double there; confirm.
      data[i] = static_cast<T>(static_cast<float>(dis(gen_local)));
#endif
    }
  }
  // Advance the shared generator past everything consumed above so that
  // subsequent draws from *gen do not repeat the generated sequence.
  gen->discard(size * rng_calls_per_val);
}

819
// Fill the tensor's CPU-side data buffer with uniform random values in
// [-2, 1), assign a random scale_inv drawn from the same range, and upload
// the result to the device. Only one buffer is filled: the rowwise one when
// present, otherwise the columnwise one.
void fillUniform(Tensor *t) {
  if (t->rowwise()) {
    const size_t size = product(t->rowwise_shape());
    // Dispatch on the runtime dtype so `T` matches the element storage type.
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T,
      {
        T *data = t->rowwise_cpu_dptr<T>();
        generate_data_uniformly(data, size, &(t->gen()));
      }
    );
  } else {
    const size_t size = product(t->columnwise_shape());
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T,
      {
        T *data = t->columnwise_cpu_dptr<T>();
        generate_data_uniformly(data, size, &(t->gen()));
      }
    );
  }
  // Randomize the scale inverse as well, using the tensor's own generator so
  // the fill is reproducible per tensor.
  std::uniform_real_distribution<> dis(-2.0, 1.0);
  t->set_scale_inv(dis(t->gen()));
  // Push the freshly generated CPU data to the GPU buffers.
  t->from_cpu();
}

// Fill the tensor's rowwise CPU buffer according to the compile-time fill
// case, reset scale_inv to 1.0, and upload to the device:
//  - zeros:   every element is 0.
//  - uniform: values drawn uniformly from [-2, 1).
//  - range cases: magnitudes drawn from the bracket
//    [ranges[Case], ranges[Case + 1]) of InputEncoding's limits table, with
//    an independently drawn random sign.
template<typename InputEncoding, InputsFillCase Case>
void fillCase_special(Tensor *t) {
  const size_t size = product(t->rowwise_shape());

  if constexpr (Case == InputsFillCase::zeros) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
      InputType *data = t->rowwise_cpu_dptr<InputType>();
      for (size_t i = 0; i < size; ++i) {
        data[i] = static_cast<InputType>(0);
      }
    });
  } else {
    // Defaults used by the `uniform` case.
    double minAbs = -2.0;
    double maxAbs = 1.0;
    if constexpr (Case != InputsFillCase::uniform) {
      // Range brackets are consecutive entries of the encoding's `ranges`
      // table (e.g. zero..minNorm, minNorm..maxNorm, maxNorm..inf).
      minAbs = Quantized_Limits<InputEncoding>::ranges[Case];
      maxAbs = Quantized_Limits<InputEncoding>::ranges[Case + 1];
    }
    std::uniform_real_distribution<> dis(minAbs, maxAbs);
    std::uniform_real_distribution<> dis_sign(-1.0, 1.0);
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
      InputType *data = t->rowwise_cpu_dptr<InputType>();
      for (size_t idx = 0; idx < size; ++idx) {
        // Sign is drawn before the magnitude; the draw order matters for
        // reproducibility of the generator sequence.
        const bool is_negative = (dis_sign(t->gen()) < 0.0);
        double val = dis(t->gen());
        if (is_negative) {
          val = -val;
        }
        data[idx] = static_cast<InputType>(val);
      }
    });
  }
  t->set_scale_inv(1.0);
  t->from_cpu();
}

878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
// Runtime-to-compile-time dispatch: route the requested fill case to the
// matching fillCase_special instantiation for the given input encoding.
template <typename InputEncoding>
void fillCase(Tensor *t, const InputsFillCase fill_case) {
  switch (fill_case) {
    case InputsFillCase::zeros:
      fillCase_special<InputEncoding, InputsFillCase::zeros>(t);
      break;
    case InputsFillCase::uniform:
      fillCase_special<InputEncoding, InputsFillCase::uniform>(t);
      break;
    case InputsFillCase::zero_to_minNorm:
      fillCase_special<InputEncoding, InputsFillCase::zero_to_minNorm>(t);
      break;
    case InputsFillCase::minNorm_to_maxNorm:
      fillCase_special<InputEncoding, InputsFillCase::minNorm_to_maxNorm>(t);
      break;
    case InputsFillCase::maxNorm_to_inf:
      fillCase_special<InputEncoding, InputsFillCase::maxNorm_to_inf>(t);
      break;
  }
}

// Explicit instantiations for the input encodings exercised by the tests.
template void fillCase<fp8e4m3>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp8e5m2>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp32>(Tensor *t, const InputsFillCase fill_case);

898
899
// Assign a random scale value, drawn uniformly from [-2, 1) using the
// tensor's own generator, to the tensor.
void setRandomScale(Tensor *t) {
  std::uniform_real_distribution<> dist(-2.0, 1.0);
  const float value = dist(t->gen());
  t->set_scale(value);
}

904
905
906
907
908
909
// Assign a random scale-inverse value, drawn uniformly from [-2, 1) using
// the tensor's own generator, to the tensor.
void setRandomScaleInv(Tensor *t) {
  std::uniform_real_distribution<> dist(-2.0, 1.0);
  const float value = dist(t->gen());
  t->set_scale_inv(value);
}

910
// True for any of the 8-bit floating-point types (E4M3, E5M2, E8M0).
bool isFp8Type(DType type) {
  switch (type) {
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
    case DType::kFloat8E8M0:
      return true;
    default:
      return false;
  }
}

914
915
916
917
int32_t getDeviceComputeCapability() {
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, 0);
  return 10 * deviceProp.major + deviceProp.minor;
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
}

// Product of every dimension except the last (the "leading" extent);
// returns 1 for empty and one-dimensional shapes.
size_t first_dimension(const std::vector<size_t> &shape) {
  if (shape.size() <= 1) return 1;
  return product(shape, 0, shape.size() - 1);
}

// The innermost (last) dimension of the shape; 1 for an empty shape.
size_t last_dimension(const std::vector<size_t> &shape) {
  return shape.empty() ? 1 : shape.back();
}

// Compute the dimensions of the scaling-factor tensor for a (rows x cols)
// data tensor quantized with the given block size. Returns
// {unpadded_rows, unpadded_cols, padded_rows, padded_cols}, where the padded
// extents are rounded up to the required alignment for the scaling direction.
std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
                                            const size_t cols,
                                            const size_t block_size_rows,
                                            const size_t block_size_cols) {
    // A 1x32 block corresponds to rowwise scaling; anything else is colwise.
    const bool is_rowwise = (block_size_rows == 1) && (block_size_cols == 32);

    const size_t align_y = is_rowwise ? scale_tensor_alignment_Y_rowwise
                                      : scale_tensor_alignment_Y_colwise;
    const size_t align_x = is_rowwise ? scale_tensor_alignment_X_rowwise
                                      : scale_tensor_alignment_X_colwise;

    const size_t raw_blocks_y = divide_round_up(rows, block_size_rows);
    const size_t raw_blocks_x = divide_round_up(cols, block_size_cols);

    const size_t padded_blocks_y = round_up_to_nearest_multiple(raw_blocks_y, align_y);
    const size_t padded_blocks_x = round_up_to_nearest_multiple(raw_blocks_x, align_x);

    return {raw_blocks_y, raw_blocks_x, padded_blocks_y, padded_blocks_x};
}

Przemek Tredak's avatar
Przemek Tredak committed
952
}  // namespace test