test_common.cu 31.9 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
/*************************************************************************
2
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
3
4
5
6
7
8
 *
 * See LICENSE for license information.
 ************************************************************************/


#include "test_common.h"
Tim Moon's avatar
Tim Moon committed
9

Przemek Tredak's avatar
Przemek Tredak committed
10
11
12
#include <algorithm>
#include <memory>
#include <random>
13
#include <iostream>
14
15
16
#include <cassert>
#include <cmath>
#include <string>
Przemek Tredak's avatar
Przemek Tredak committed
17

Tim Moon's avatar
Tim Moon committed
18
#include <gtest/gtest.h>
19
#include <omp.h>
Tim Moon's avatar
Tim Moon committed
20
21
22
23

#include <transformer_engine/transformer_engine.h>
#include "util/logging.h"

Przemek Tredak's avatar
Przemek Tredak committed
24
25
namespace test {

26
27
28
29
30
31
size_t create_seed_from_tensor_name(const std::string& tensor_name) {
  auto full_name = std::string(testing::UnitTest::GetInstance()->current_test_info()->name()) +
                   "/" + tensor_name;
  return std::hash<std::string>{}(full_name);
}

Przemek Tredak's avatar
Przemek Tredak committed
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
// List of the floating-point DType values (FP32, FP16, BF16 and the two FP8
// encodings E5M2 / E4M3) exposed as a mutable global vector.
std::vector<DType> all_fp_types = {DType::kFloat32,
                                   DType::kFloat16,
                                   DType::kBFloat16,
                                   DType::kFloat8E5M2,
                                   DType::kFloat8E4M3};

// Two shapes are equal when they have the same rank and identical extents
// in every dimension.
bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
  const bool same_rank = (s1.ndim == s2.ndim);
  return same_rank && std::equal(s1.data, s1.data + s1.ndim, s2.data);
}

// Returns TypeInfo<T>::size for the C++ type T that the type-switch macro
// selects for `type`. Elsewhere in this file the result is multiplied by an
// element count to size host/device buffers.
size_t typeToSize(DType type) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
  {
      return TypeInfo<T>::size;
  });
}

// Maps a DType to its printable name.
// The lookup uses unordered_map::at, so a DType that is missing from the
// table raises std::out_of_range.
const std::string &typeName(DType type) {
  using NameTable = std::unordered_map<DType, std::string>;
  static const NameTable kTypeNames{
      {DType::kFloat32, "float32"},
      {DType::kFloat16, "float16"},
      {DType::kBFloat16, "bfloat16"},
      {DType::kFloat8E4M3, "float8e4m3"},
      {DType::kFloat8E5M2, "float8e5m2"},
      {DType::kFloat8E8M0, "float8e8m0"},
      {DType::kByte, "byte"},
      {DType::kInt32, "int32"},
      {DType::kInt64, "int64"},
  };
  return kTypeNames.at(type);
}

69
70
71
72
73
74
75
76
77
78
79
// Maps an InputsFillCase to its printable name.
// The lookup uses unordered_map::at, so an unknown case raises
// std::out_of_range.
const std::string& caseName(InputsFillCase type) {
  using CaseTable = std::unordered_map<InputsFillCase, std::string>;
  static const CaseTable kCaseNames{
      {InputsFillCase::maxNorm_to_inf, "maxNorm_to_inf"},
      {InputsFillCase::minNorm_to_maxNorm, "minNorm_to_maxNorm"},
      {InputsFillCase::zero_to_minNorm, "zero_to_minNorm"},
      {InputsFillCase::zeros, "zeros"},
      {InputsFillCase::uniform, "uniform"},
  };
  return kCaseNames.at(type);
}

size_t product(const NVTEShape &shape, size_t begin, size_t end) {
Przemek Tredak's avatar
Przemek Tredak committed
80
    size_t ret = 1;
81
82
    NVTE_CHECK(end <= shape.ndim);
    for (size_t i = begin; i < end; ++i) {
Przemek Tredak's avatar
Przemek Tredak committed
83
84
85
86
      ret *= shape.data[i];
    }
    return ret;
}
87

88
89
90
size_t product(const NVTEShape &shape) {
  return product(shape, 0, shape.ndim);
}
91

92
93
94
95
96
97
98
99
// Multiplies the entries of `shape` with indices in [begin, end).
// `end` must not exceed shape.size(); an empty range yields 1.
size_t product(const std::vector<size_t> shape, size_t begin, size_t end) {
  NVTE_CHECK(end <= shape.size());
  size_t result = 1;
  size_t dim = begin;
  while (dim < end) {
    result *= shape[dim];
    ++dim;
  }
  return result;
}
Przemek Tredak's avatar
Przemek Tredak committed
100

101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
// Product of all entries of `shape`.
size_t product(const std::vector<size_t>& shape) {
  const size_t rank = shape.size();
  return product(shape, 0, rank);
}

// Ceiling division: smallest q such that q * y >= x.
// Written as quotient plus a remainder correction so that it cannot wrap
// around for x close to the maximum size_t (the `x + y - 1` form can).
// Precondition: y != 0.
size_t DIVUP(const size_t &x, const size_t &y){
  return x / y + ((x % y != 0) ? 1 : 0);
}

// Description of one scale-inverse buffer as produced by get_scales().
struct scale_inv_meta {
  std::vector<size_t> shape;  // extents of the scale-inverse tensor
  DType type;                 // element type of the scale-inverse tensor
  size_t type_size;           // sizeof() of one element, used to size buffers
};

// Wraps a vector of extents as an NVTEShape.
// The result stores shape.data(), i.e. it points into the vector's storage.
NVTEShape convertShape(const std::vector<size_t>& shape) {
  NVTEShape view{shape.data(), shape.size()};
  return view;
}

std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
                                                     const NVTEScalingMode scaling_mode) {
  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
    scale_inv_meta ret;
    ret.shape = {1};
    ret.type = DType::kFloat32;
    ret.type_size = sizeof(float);
    return {ret, ret};
  }
  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    scale_inv_meta ret_rowwise, ret_colwise;

138
    auto block_alignment = std::vector<size_t>{128ul, 4ul};
139
140
    {
      auto alignment = block_alignment[0];
141
      auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(1)), alignment) * alignment;
142
      alignment = block_alignment[1];
143
      auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(32)), alignment) * alignment;
144
145
146
147
      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
    }
    {
      auto alignment = block_alignment[1];
148
      auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(32)), alignment) * alignment;
149
      alignment = block_alignment[0];
150
      auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(1)), alignment) * alignment;
151
      ret_colwise.shape = {scale_dim_0, scale_dim_1};
Przemek Tredak's avatar
Przemek Tredak committed
152
    }
153
154
155
156
157
158
159
    ret_rowwise.type = DType::kFloat8E8M0;
    ret_colwise.type = DType::kFloat8E8M0;
    ret_rowwise.type_size = sizeof(uint8_t);
    ret_colwise.type_size = sizeof(uint8_t);

    return {ret_rowwise, ret_colwise};
  }
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
  if (scaling_mode == NVTE_BLOCK_SCALING_2D) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    scale_inv_meta ret_rowwise, ret_colwise;

    {
      auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
      auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(128)), 4) * 4;
      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
    }
    {
      auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
      auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(128)), 4) * 4;
      ret_colwise.shape = {scale_dim_0, scale_dim_1};
    }
    ret_rowwise.type = DType::kFloat32;
    ret_colwise.type = DType::kFloat32;
    ret_rowwise.type_size = sizeof(float);
    ret_colwise.type_size = sizeof(float);

    return {ret_rowwise, ret_colwise};
  }
  if (scaling_mode == NVTE_BLOCK_SCALING_1D) {
    std::vector<size_t> shape_vec;
    for (size_t i = 0; i < shape.ndim; ++i) {
      shape_vec.push_back(shape.data[i]);
    }
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);
    scale_inv_meta ret_rowwise, ret_colwise;

    {
      auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
      auto scale_dim_1 = DIVUP(first_dim, 4) * 4;
      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
    }
    {
      auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
      auto scale_dim_1 = DIVUP(last_dim, 4) * 4;
      ret_colwise.shape = {scale_dim_0, scale_dim_1};
    }
    ret_rowwise.type = DType::kFloat32;
    ret_colwise.type = DType::kFloat32;
    ret_rowwise.type_size = sizeof(float);
    ret_colwise.type_size = sizeof(float);
    return {ret_rowwise, ret_colwise};
  }
212
213
214
215
216
217
218

  NVTE_ERROR("Invalid scaling mode!");
}

// Constructs a test tensor with zero-initialized device and host storage.
//
//  name         - tensor name; combined with the running test's name to seed
//                 the per-tensor random generator.
//  shape, type  - logical shape and element type of the data.
//  rowwise      - allocate the rowwise data (and, for FP8, its scale_inv).
//  columnwise   - allocate the columnwise data; requires at least 2 dims.
//  scaling_mode - selects the scale_inv layout (see get_scales()) and whether
//                 the columnwise shape is transposed.
//  q_opts       - optional quantization options; when given for an FP8 tensor
//                 only force_pow_2_scales == true and amax_epsilon == 0 are
//                 accepted.
Tensor::Tensor(const std::string& name,
               const NVTEShape &shape, const DType type,
               const bool rowwise, const bool columnwise,
               const NVTEScalingMode &scaling_mode,
               const QuantizationOptions* q_opts) {
  name_ = name;
  const size_t seed = create_seed_from_tensor_name(name);
  gen_.seed(seed);
  rowwise_ = rowwise;
  columnwise_ = columnwise;
  size_t s = typeToSize(type);
  size_t total_size = product(shape) * s;
  void *dptr_rowwise = nullptr;
  void *dptr_columnwise = nullptr;
  cpu_data_rowwise_ = nullptr;
  cpu_data_columnwise_ = nullptr;
  amax_cpu_data_ = nullptr;
  scale_cpu_data_ = nullptr;
  rowwise_scale_inv_cpu_data_ = nullptr;
  columnwise_scale_inv_cpu_data_ = nullptr;
  float *amax = nullptr, *scale = nullptr;
  float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr;
  // `shape.ndim - 1` is used below; make the rank requirement explicit instead
  // of relying on the range check inside product() after the subtraction wraps.
  NVTE_CHECK(shape.ndim >= 1, "Tensor shape must have at least one dimension.");
  if (columnwise) {
    NVTE_CHECK(shape.ndim >= 2);
  }
  // 2D view of the tensor: all leading dimensions flattened, last kept.
  std::vector<size_t> normalized_shape_v = {product(shape, 0, shape.ndim - 1),
                                            shape.data[shape.ndim - 1]};
  NVTEShape normalized_shape = convertShape(normalized_shape_v);
  NVTEShape columnwise_shape{nullptr, 0};

  std::vector<size_t> columnwise_shape_vec;
  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D ||
      scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // Transpose when tensor scaling: last dimension moves to the front.
    columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
    for (size_t i = 0; i < shape.ndim - 1; ++i) {
      columnwise_shape_vec.emplace_back(shape.data[i]);
    }
  } else {
    // Same shape for MX
    for (size_t i = 0; i < shape.ndim; ++i) {
      columnwise_shape_vec.emplace_back(shape.data[i]);
    }
  }

  // Without columnwise data the columnwise shape stays {nullptr, 0}.
  if (columnwise) {
    columnwise_shape.data = columnwise_shape_vec.data();
    columnwise_shape.ndim = columnwise_shape_vec.size();
  }

  tensor_ = TensorWrapper(scaling_mode);

  // Data buffers: device memory plus a host mirror, both zero-filled.
  if (total_size != 0) {
    if (rowwise) {
      cudaMalloc((void**)&dptr_rowwise, total_size);  // NOLINT(*)
      cudaMemset(dptr_rowwise, 0, total_size);
      cpu_data_rowwise_ = std::make_unique<unsigned char[]>(total_size);
      std::fill_n(cpu_data_rowwise_.get(), total_size, 0);
    }
    if (columnwise) {
      cudaMalloc((void**)&dptr_columnwise, total_size);  // NOLINT(*)
      cudaMemset(dptr_columnwise, 0, total_size);
      cpu_data_columnwise_ = std::make_unique<unsigned char[]>(total_size);
      std::fill_n(cpu_data_columnwise_.get(), total_size, 0);
    }
  }
  tensor_.set_rowwise_data(dptr_rowwise, type, shape);
  tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape);

  if (isFp8Type(type)) {
    if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
      cudaMalloc((void**)&amax, sizeof(float));  // NOLINT(*)
      cudaMemset(amax, 0, sizeof(float));
      cudaMalloc((void**)&scale, sizeof(float));  // NOLINT(*)
      cudaMemset(scale, 0, sizeof(float));
      amax_cpu_data_ = std::make_shared<float>(0);
      scale_cpu_data_ = std::make_shared<float>(0);
      tensor_.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
      tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      // One device scale_inv shared by the rowwise and columnwise views.
      cudaMalloc((void**)&rowwise_scale_inv, sizeof(float));  // NOLINT(*)
      // Zero it like every other buffer so to_cpu() never reads garbage.
      cudaMemset(rowwise_scale_inv, 0, sizeof(float));
      if (rowwise) {
        tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
                                      std::vector<size_t>{1});
        rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
        std::fill_n(rowwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
      }
      if (columnwise) {
        tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
                                         std::vector<size_t>{1});
        columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
        std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
      }
    } else {
      // Block / MX scaling: scale_inv layout comes from the flattened 2D shape.
      auto [rowwise_scale_meta, colwise_scale_meta] =
          get_scales(normalized_shape, tensor_.scaling_mode());
      auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
      auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
      auto scale_shape = rowwise_scale_meta.shape;
      auto columnwise_scale_shape = colwise_scale_meta.shape;
      if (rowwise) {
        cudaMalloc((void **)&rowwise_scale_inv, rowwise_scale_size);  // NOLINT(*)
        cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size);
        rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(rowwise_scale_size);
        std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0);
        auto scale_dtype = rowwise_scale_meta.type;
        tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape);
      }
      if (columnwise) {
        cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size);  // NOLINT(*)
        cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size);
        columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(columnwise_scale_size);
        std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0);
        auto scale_dtype = colwise_scale_meta.type;
        tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape);
      }
    }
    if (q_opts != nullptr) {
      NVTE_CHECK(q_opts->force_pow_2_scales, "Pow2 scales is required for current implementation.");
      NVTE_CHECK(q_opts->amax_epsilon == 0.0, "Amax epsilon must be zero for current implementation.");
    }
  }
}

void Tensor::to_cpu() const {
  const NVTEShape s = tensor_.shape();
  const size_t size = product(s) * typeToSize(tensor_.dtype());
341
342
343
344
345
346
347
348
349
350
351
352
  if (rowwise_) {
    cudaMemcpy(cpu_data_rowwise_.get(),
               tensor_.get_rowwise_data().data_ptr,
               size,
               cudaMemcpyDeviceToHost);
  }
  if (columnwise_) {
    cudaMemcpy(cpu_data_columnwise_.get(),
               tensor_.get_columnwise_data().data_ptr,
               size,
               cudaMemcpyDeviceToHost);
  }
353
  if (isFp8Type(dtype())) {
354
355
356
357
358
359
360
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
      if (tensor_.amax() != nullptr){
        cudaMemcpy(amax_cpu_data_.get(),
                  tensor_.amax(),
                  sizeof(float),
                  cudaMemcpyDeviceToHost);
      }
361
362
363
364
365
      cudaMemcpy(scale_cpu_data_.get(),
                 tensor_.scale(),
                 sizeof(float),
                 cudaMemcpyDeviceToHost);
    }
366
367
    auto [rowwise_scale_meta, colwise_scale_meta] =
        get_scales(s, tensor_.scaling_mode());
368
369
370
371
372
373
374
375
376
377
378
379
380
381
    if (rowwise_) {
      auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
      cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
                 tensor_.get_rowwise_scale_inv().data_ptr,
                 scale_size,
                 cudaMemcpyDeviceToHost);
    }
    if (columnwise_) {
      auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
      cudaMemcpy(columnwise_scale_inv_cpu_data_.get(),
                 tensor_.get_columnwise_scale_inv().data_ptr,
                 scale_size,
                 cudaMemcpyDeviceToHost);
    }
382
  }
Przemek Tredak's avatar
Przemek Tredak committed
383
384
385
386
387
}

void Tensor::from_cpu() const {
  const NVTEShape s = tensor_.shape();
  const size_t size = product(s) * typeToSize(tensor_.dtype());
388
389
390
391
392
393
394
395
  if (rowwise_) {
    cudaMemcpy(tensor_.get_rowwise_data().data_ptr,
               cpu_data_rowwise_.get(), size, cudaMemcpyHostToDevice);
  }
  if (columnwise_) {
    cudaMemcpy(tensor_.get_columnwise_data().data_ptr,
               cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice);
  }
396
  if (isFp8Type(dtype())) {
397
398
399
400
401
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
      if (tensor_.amax() != nullptr){
        cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float),
                  cudaMemcpyHostToDevice);
      }
402
403
404
      cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float),
                 cudaMemcpyHostToDevice);
    }
405
406
    auto [rowwise_scale_meta, colwise_scale_meta] =
        get_scales(s, tensor_.scaling_mode());
407
408
409
410
411
412
413
414
415
416
417
418
    if (rowwise_) {
      auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
      cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
                 rowwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
    }
    if (columnwise_) {
      auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
      cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr,
                 columnwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
    }
419
420
421
422
423
424
  }
}

void Tensor::set_scale(float scale) {
  if (isFp8Type(dtype())) {
    NVTE_CHECK(scale_cpu_data_);
425
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
426
427
428
      *scale_cpu_data_ = scale;
      from_cpu();
    }
429
430
431
432
433
  }
}

void Tensor::set_scale_inv(float scale_inv) {
  if (isFp8Type(dtype())) {
434
435
436
437
438
439
    if (rowwise_) {
      NVTE_CHECK(rowwise_scale_inv_cpu_data_);
    }
    if (columnwise_) {
      NVTE_CHECK(columnwise_scale_inv_cpu_data_);
    }
440
441
442

    auto [rowwise_scale_meta, colwise_scale_meta] =
        get_scales(tensor_.shape(), tensor_.scaling_mode());
443
444
    if (rowwise_) {
      auto num_scales = product(rowwise_scale_meta.shape);
445
      if (num_scales == 1) {
446
        rowwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
447
      } else {
448
        std::uniform_int_distribution<uint8_t> dis(0, 127);
449
450
        auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr<uint8_t>();
        for (size_t i = 0; i < num_scales; i++) {
451
452
453
454
455
456
          scale_inv_ptr[i] = dis(gen_);
        }
      }
    }
    if (columnwise_) {
      auto num_scales = product(colwise_scale_meta.shape);
457
      if (num_scales == 1) {
458
        columnwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
459
      } else {
460
        std::uniform_int_distribution<uint8_t> dis(0, 127);
461
462
        auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr<uint8_t>();
        for (size_t i = 0; i < num_scales; i++) {
463
464
465
466
          scale_inv_ptr[i] = dis(gen_);
        }
      }
    }
467
468
469
470
471
    from_cpu();
  }
}

void Tensor::shareFP8Meta(const Tensor &other) {
472
  if (isFp8Type(dtype()) && isFp8Type(other.dtype())) {
473
474
    auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
    auto my_rowwise_data = tensor_.get_rowwise_data();
475
    new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
476
477
478
479
480
481
                                my_rowwise_data.shape);
    auto my_columnwise_data = tensor_.get_columnwise_data();
    new_tensor.set_columnwise_data(my_columnwise_data.data_ptr,
                                   static_cast<DType>(my_columnwise_data.dtype),
                                   my_columnwise_data.shape);
    auto other_amax = other.tensor_.get_amax();
482
    new_tensor.set_amax(other_amax.data_ptr, static_cast<DType>(other_amax.dtype),
483
484
                        other_amax.shape);
    auto other_scale = other.tensor_.get_scale();
485
    new_tensor.set_scale(other_scale.data_ptr, static_cast<DType>(other_scale.dtype),
486
487
488
489
490
491
492
493
494
495
                         other_scale.shape);
    auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv();
    new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr,
                                     static_cast<DType>(other_row_scale_inv.dtype),
                                     other_row_scale_inv.shape);
    auto other_col_scale_inv = other.tensor_.get_columnwise_scale_inv();
    new_tensor.set_columnwise_scale_inv(other_col_scale_inv.data_ptr,
                                        static_cast<DType>(other_col_scale_inv.dtype),
                                        other_col_scale_inv.shape);
    tensor_ = std::move(new_tensor);
496
497
    to_cpu();
  }
Przemek Tredak's avatar
Przemek Tredak committed
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
}

using std::to_string;

// Formats a vector as "[a, b, c]" using to_string on each element.
// An empty vector yields "[]" (the previous implementation popped two
// characters off a one-character string in that case).
template <typename T>
std::string to_string(const std::vector<T> &v) {
  std::string s = "[";
  bool first = true;
  for (const auto &x : v) {
    if (!first) {
      s += ", ";
    }
    s += to_string(x);
    first = false;
  }
  return s + "]";
}

std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
  std::vector<size_t> ret;
  size_t current_i = i;
516
  for (size_t current = shape.ndim - 1; current > 0; --current) {
Przemek Tredak's avatar
Przemek Tredak committed
517
518
519
520
521
522
523
524
    ret.push_back(current_i % shape.data[current]);
    current_i /= shape.data[current];
  }
  ret.push_back(current_i);
  std::reverse(ret.begin(), ret.end());
  return ret;
}

525
526
527
528
529
530
void compareResults_sequential(const std::string &name, const Tensor &test,
                               const void *ref, const bool rowwise,
                               double atol, double rtol, bool if_on_gpus) {
  if (if_on_gpus) test.to_cpu();
  const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
  const size_t N = product(shape);
Przemek Tredak's avatar
Przemek Tredak committed
531
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
532
    const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
Przemek Tredak's avatar
Przemek Tredak committed
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
    const T *ref_data = reinterpret_cast<const T*>(ref);
    for (size_t i = 0; i < N; ++i) {
      double t = static_cast<double>(test_data[i]);
      double r = static_cast<double>(ref_data[i]);
      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
      /* For Float32 the floating point comparison is enough to error out */
      bool assertion = mismatch && test.dtype() == DType::kFloat32;
      if (mismatch && !assertion) {
        /* Check if it is just a failure of round to nearest choosing different
           side of the real value */
        const double mean = (t + r) / 2;
        const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
        const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
        const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
        const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
        assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
      }
550
551
552
553
      std::string direction = rowwise ? "rowwise" : "columnwise";
      ASSERT_FALSE(assertion) << "Error in tensor " << name << " in "
                              << direction << " direction." << std::endl
                              << "Mismatch at place " << to_string(unravel(i, shape))
Przemek Tredak's avatar
Przemek Tredak committed
554
                              << " (" << std::to_string(i) << "): " << t << " vs " << r;
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
    }
  );
}

// Returns the smallest index at which `test_data` unacceptably differs from
// `ref_data`, or N when there is none. The scan is parallelized with OpenMP
// and combined with a min-reduction.
//
// An element mismatches when the absolute error exceeds `atol` and either the
// reference is zero or the relative error exceeds `rtol`. For Float32 that is
// final; for other types a mismatch is tolerated when test and reference are
// the two values that the midpoint rounds to when nudged slightly down / up.
//
// The running minimum is a size_t (it used to be an int, which truncated N
// and compared signed against unsigned).
template <typename T>
static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data,
                                  const size_t N, const double atol, const double rtol) {
  size_t first_mismatch_idx = N;

  bool is_mismatch_found = false;
  #pragma omp parallel for schedule(static) firstprivate(is_mismatch_found) \
    reduction(min: first_mismatch_idx) proc_bind(spread)
  for (size_t i = 0; i < N; ++i) {
    if (is_mismatch_found) {    // early escape of the omp thread
      continue;
    }

    double t = static_cast<double>(test_data[i]);
    double r = static_cast<double>(ref_data[i]);

    bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
    /* For Float32 the floating point comparison is enough to error out */
    bool assertion = mismatch && (data_type == DType::kFloat32);
    if (mismatch && !assertion) {
      /* Check if it is just a failure of round to nearest choosing different
          side of the real value */
      const double mean = (t + r) / 2;
      const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
      const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
      const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
      const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
      assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
    }
    if (assertion && i < first_mismatch_idx) {
      first_mismatch_idx = i;
      is_mismatch_found = true;
    }
  }
  return first_mismatch_idx;
}

void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref,
                             const bool rowwise, double atol, double rtol, bool if_on_gpus) {
  if (if_on_gpus) test.to_cpu();
  const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
  const size_t N = product(shape);
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
    const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
    const T *ref_data = reinterpret_cast<const T*>(ref);
Przemek Tredak's avatar
Przemek Tredak committed
604

605
606
607
608
609
610
611
612
613
    const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol);
    if (i != N) {
      const double t = static_cast<double>(test_data[i]);
      const double r = static_cast<double>(ref_data[i]);
      std::string direction = rowwise ? "rowwise" : "columnwise";
      ASSERT_FALSE(true) << "Error in tensor " << name << " in "
                         << direction << " direction." << std::endl
                         << "Mismatch at place " << to_string(unravel(i, shape))
                         << " (" << std::to_string(i) << "): " << t << " vs " << r;
Przemek Tredak's avatar
Przemek Tredak committed
614
615
616
617
    }
  );
}

618
619
620
621
622
623
624
625
626
627
// Entry point for tensor-vs-reference comparison. A compile-time switch picks
// between the sequential and the OpenMP-parallel implementation; the parallel
// one is currently selected.
void compareResults(const std::string &name, const Tensor &test, const void *ref,
                    const bool rowwise, double atol, double rtol, bool if_on_gpus) {
  constexpr bool use_parallel = true;
  if constexpr (use_parallel) {
    compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus);
  } else {
    compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus);
  }
}

628
629
630
631
632
633
634
635
636
637
// Scalar comparison: fails the test (gtest ASSERT) when the absolute error
// exceeds `atol` and, in addition, the reference is zero or the relative
// error exceeds `rtol`.
void compareResults(const std::string &name, const float test, const float ref,
                    double atol, double rtol) {
  const double t = static_cast<double>(test);
  const double r = static_cast<double>(ref);
  const double diff = t - r;
  const bool exceeds_atol = fabs(diff) > atol;
  const bool exceeds_rtol = (r == 0) || (fabs(diff / r) > rtol);
  const bool mismatch = exceeds_atol && exceeds_rtol;
  ASSERT_FALSE(mismatch) << "Error in " << name << std::endl
                         << "Mismatch: " << t << " vs " << r;
}

638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682

// Byte-wise comparison that tolerates a bounded number of mismatches.
// Up to ceil(N * mismatch_rate_tol) differing bytes are accepted; as soon as
// that budget is exceeded, every mismatch seen so far is printed and the test
// fails.
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
                    size_t N, float mismatch_rate_tol) {
  size_t max_mismatches = std::ceil(N * mismatch_rate_tol);
  size_t n_mismatches = 0;
  std::vector<size_t> mismatch_indices;
  for (size_t i = 0; i < N; i++){
    bool mismatch = test[i] != ref[i];
    if (mismatch){
      n_mismatches++;
      mismatch_indices.push_back(i);
    }
    if (n_mismatches > max_mismatches){
      std::cout << "Error in " << name << std::endl;
      // Report the values at each recorded index (not at the current `i`).
      for (const size_t index : mismatch_indices)
        std::cout << "Mismatch at (" << index << "):" << static_cast<int>(test[index]) << " vs "
        << static_cast<int>(ref[index]) << std::endl;
      GTEST_FAIL() << n_mismatches << " mismatche(s) which is more than mismatch tol.";
    }
  }
}

// Exact comparison of a row_blocks x col_blocks window of byte scaling
// factors stored with row pitch `stride`; fails the test (gtest ASSERT) at
// the first difference. Indices are size_t so that large buffers do not
// overflow a signed int.
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
                                  const size_t row_blocks, const size_t col_blocks, const size_t stride)
{
  for (size_t i = 0; i < row_blocks; ++i) {
    for (size_t j = 0; j < col_blocks; ++j) {
      const size_t idx = i * stride + j;
      ASSERT_FALSE(test[idx] != ref[idx]) << "Error in " << name << std::endl
        << "Mismatch: " << static_cast<int>(test[idx]) << " vs "
        << static_cast<int>(ref[idx]) << " at index " << idx;
    }
  }
}

// Exact comparison of N contiguous byte scaling factors; fails the test
// (gtest ASSERT) at the first difference. The index is size_t to match N.
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
                                  const size_t N)
{
  for (size_t i = 0; i < N; i++) {
    ASSERT_FALSE(test[i] != ref[i]) << "Error in " << name << std::endl
      << "Mismatch: " << static_cast<int>(test[i]) << " vs "
      << static_cast<int>(ref[i]) << " at index " << i;
  }
}

Przemek Tredak's avatar
Przemek Tredak committed
683
684
685
686
687
688
689
690
691
692
std::pair<double, double> getTolerances(const DType type) {
  switch(type) {
    case DType::kFloat32:
      return {1e-6, 5e-6};
    case DType::kFloat16:
      return {1e-5, 1e-3};
    case DType::kBFloat16:
      return {1e-5, 1e-2};
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
693
    case DType::kFloat8E8M0:
Przemek Tredak's avatar
Przemek Tredak committed
694
695
696
697
698
699
700
      return {1e-2, 1e-2};
    default:
      NVTE_CHECK("Invalid type!");
  }
  return {0, 0};
}

701
702
703
704
705
template <typename T>
void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
  #pragma omp parallel proc_bind(spread)
  {
    std::mt19937 gen_local = *gen;
706
707
708
709
710
711
    const int thread_ID = omp_get_thread_num();
    const int threads_num = omp_get_max_threads();
    const int chunk_size = (size + threads_num - 1) / threads_num;
    const int idx_min = chunk_size * thread_ID;
    const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast<int>(size));
    gen_local.discard(idx_min);
712
    std::uniform_real_distribution<> dis(-2.0, 1.0);
713
714

    for (int i = idx_min; i < idx_max; ++i) {
715
716
717
718
719
720
      data[i] = static_cast<T>(dis(gen_local));
    }
  }
  gen->discard(size);
}

721
// Fills the host-side buffer of a tensor with uniformly distributed random values,
// assigns a random scale_inv and then calls from_cpu() on the tensor.
//
// The rowwise buffer is filled when the tensor reports rowwise(); otherwise the
// columnwise buffer is filled. The tensor's own generator (t->gen()) is used throughout.
void fillUniform(Tensor *t) {
  const bool use_rowwise = t->rowwise();
  const size_t num_elements = use_rowwise ? product(t->rowwise_shape())
                                          : product(t->columnwise_shape());

  // Dispatch on the tensor's dtype so the buffer is written with the matching element type.
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T,
    {
      T *buffer = use_rowwise ? t->rowwise_cpu_dptr<T>()
                              : t->columnwise_cpu_dptr<T>();
      generate_data_uniformly(buffer, num_elements, &(t->gen()));
    }
  );

  // scale_inv is sampled from the same range as the data, after the data has been generated.
  std::uniform_real_distribution<> scale_inv_dist(-2.0, 1.0);
  const double sampled_scale_inv = scale_inv_dist(t->gen());
  t->set_scale_inv(sampled_scale_inv);
  t->from_cpu();
}

// Fills the rowwise host buffer of a tensor according to the compile-time fill case,
// sets scale_inv to 1 and calls from_cpu() on the tensor.
//
//   InputEncoding - type used to look up Quantized_Limits<InputEncoding>::ranges.
//   Case          - fill case. `zeros` writes zeros; every other case writes random
//                   magnitudes with a random sign.
//
// For `uniform` the magnitude bounds are -2 and 1; for the remaining cases they are read
// from Quantized_Limits<InputEncoding>::ranges at positions Case and Case + 1.
template<typename InputEncoding, InputsFillCase Case>
void fillCase_special(Tensor *t) {
  const auto shape = t->rowwise_shape();
  const size_t size = product(shape);
  const size_t rows = shape.data[0];
  const size_t cols = shape.data[1];

  if constexpr (Case == InputsFillCase::zeros) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
      InputType *dst = t->rowwise_cpu_dptr<InputType>();
      std::fill_n(dst, size, static_cast<InputType>(0));
    });
  } else {
    double lower_bound = -2.0;
    double upper_bound = 1.0;
    if constexpr (Case != InputsFillCase::uniform) {
      lower_bound = Quantized_Limits<InputEncoding>::ranges[Case];
      upper_bound = Quantized_Limits<InputEncoding>::ranges[Case + 1];
    }
    std::uniform_real_distribution<> magnitude_dist(lower_bound, upper_bound);
    std::uniform_real_distribution<> sign_dist(-1.0, 1.0);
    const size_t element_count = rows * cols;
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
      InputType *dst = t->rowwise_cpu_dptr<InputType>();
      // Elements are visited in row-major order. For each one the sign is drawn first
      // and the magnitude second, both from the tensor's generator.
      for (size_t idx = 0; idx < element_count; ++idx) {
        const bool flip_sign = (sign_dist(t->gen()) < 0.0);
        const double magnitude = magnitude_dist(t->gen());
        dst[idx] = static_cast<InputType>(flip_sign ? -magnitude : magnitude);
      }
    });
  }
  t->set_scale_inv(1.0);
  t->from_cpu();
}

785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
// Runtime-to-compile-time dispatcher: forwards to the fillCase_special instantiation
// that matches `fill_case`. Values of `fill_case` not listed below leave the tensor untouched.
template <typename InputEncoding>
void fillCase(Tensor *t, const InputsFillCase fill_case) {
  switch (fill_case) {
    case InputsFillCase::uniform: {
      fillCase_special<InputEncoding, InputsFillCase::uniform>(t);
      return;
    }
    case InputsFillCase::zeros: {
      fillCase_special<InputEncoding, InputsFillCase::zeros>(t);
      return;
    }
    case InputsFillCase::zero_to_minNorm: {
      fillCase_special<InputEncoding, InputsFillCase::zero_to_minNorm>(t);
      return;
    }
    case InputsFillCase::minNorm_to_maxNorm: {
      fillCase_special<InputEncoding, InputsFillCase::minNorm_to_maxNorm>(t);
      return;
    }
    case InputsFillCase::maxNorm_to_inf: {
      fillCase_special<InputEncoding, InputsFillCase::maxNorm_to_inf>(t);
      return;
    }
  }
}

// Explicit instantiations of fillCase for the fp8e4m3, fp8e5m2 and fp32 input encodings,
// so the template definition in this translation unit is emitted for those types.
template void fillCase<fp8e4m3>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp8e5m2>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp32>(Tensor *t, const InputsFillCase fill_case);

805
806
// Draws one value from a uniform distribution over [-2, 1) using the tensor's own
// generator and installs it, narrowed to float, as the tensor's scale.
void setRandomScale(Tensor *t) {
  std::uniform_real_distribution<> scale_dist(-2.0, 1.0);
  auto &rng = t->gen();
  const float sampled_scale = scale_dist(rng);
  t->set_scale(sampled_scale);
}

811
812
813
814
815
816
// Draws one value from a uniform distribution over [-2, 1) using the tensor's own
// generator and installs it, narrowed to float, as the tensor's scale_inv.
void setRandomScaleInv(Tensor *t) {
  std::uniform_real_distribution<> scale_inv_dist(-2.0, 1.0);
  auto &rng = t->gen();
  const float sampled_scale_inv = scale_inv_dist(rng);
  t->set_scale_inv(sampled_scale_inv);
}

817
// Returns true exactly for the three 8-bit floating-point data types
// (kFloat8E4M3, kFloat8E5M2, kFloat8E8M0) and false for everything else.
bool isFp8Type(DType type) {
  switch (type) {
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
    case DType::kFloat8E8M0:
      return true;
    default:
      return false;
  }
}

821
822
823
824
int32_t getDeviceComputeCapability() {
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, 0);
  return 10 * deviceProp.major + deviceProp.minor;
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
}

// Returns product(shape, 0, shape.size() - 1) for shapes with two or more entries.
// Shapes with zero or one entry yield 1.
size_t first_dimension(const std::vector<size_t> &shape) {
  const size_t rank = shape.size();
  if (rank < 2) {
    return 1;
  }
  return product(shape, 0, rank - 1);
}

// Returns the last entry of `shape`, or 1 when `shape` is empty.
size_t last_dimension(const std::vector<size_t> &shape) {
  return shape.empty() ? static_cast<size_t>(1) : shape.back();
}

// Computes the dimensions of the scaling-factor tensor for a rows x cols tensor that is
// split into blocks of block_size_rows x block_size_cols elements.
//
// Returns {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X}: the block counts
// along each axis (rounded-up division), followed by the same counts rounded up to the
// alignment of the selected layout. The rowwise alignment constants are selected only for
// 1 x 32 blocks; every other block shape uses the colwise alignment constants.
std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
                                            const size_t cols,
                                            const size_t block_size_rows,
                                            const size_t block_size_cols) {
    // Number of blocks along each axis before any padding.
    const size_t unpadded_blocks_Y = divide_round_up(rows, block_size_rows);
    const size_t unpadded_blocks_X = divide_round_up(cols, block_size_cols);

    // Pick the alignment pair that matches the block layout.
    size_t alignment_Y = scale_tensor_alignment_Y_colwise;
    size_t alignment_X = scale_tensor_alignment_X_colwise;
    if (block_size_rows == 1 && block_size_cols == 32) {
        alignment_Y = scale_tensor_alignment_Y_rowwise;
        alignment_X = scale_tensor_alignment_X_rowwise;
    }

    // Pad the block counts up to the chosen alignment.
    const size_t blocks_Y = round_up_to_nearest_multiple(unpadded_blocks_Y, alignment_Y);
    const size_t blocks_X = round_up_to_nearest_multiple(unpadded_blocks_X, alignment_X);

    return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X};
}

Przemek Tredak's avatar
Przemek Tredak committed
859
}  // namespace test