/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/


#include "test_common.h"
Tim Moon's avatar
Tim Moon committed
9

Przemek Tredak's avatar
Przemek Tredak committed
10
11
12
#include <algorithm>
#include <memory>
#include <random>
13
#include <iostream>
14
15
16
#include <cassert>
#include <cmath>
#include <string>
Przemek Tredak's avatar
Przemek Tredak committed
17

Tim Moon's avatar
Tim Moon committed
18
#include <gtest/gtest.h>
19
#include <omp.h>
Tim Moon's avatar
Tim Moon committed
20
21
22
23

#include <transformer_engine/transformer_engine.h>
#include "util/logging.h"

maxiao3's avatar
maxiao3 committed
24
25
#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)

Przemek Tredak's avatar
Przemek Tredak committed
26
27
namespace test {

28
29
30
31
32
33
// Derives a deterministic RNG seed from the current gtest test name combined
// with the tensor name, so every tensor in every test gets reproducible data.
size_t create_seed_from_tensor_name(const std::string& tensor_name) {
  const auto *test_info = testing::UnitTest::GetInstance()->current_test_info();
  const std::string full_name = std::string(test_info->name()) + "/" + tensor_name;
  return std::hash<std::string>{}(full_name);
}

Przemek Tredak's avatar
Przemek Tredak committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
// Floating-point dtypes exercised by the tests (standard and FP8 formats).
std::vector<DType> all_fp_types = {DType::kFloat32,
                                   DType::kFloat16,
                                   DType::kBFloat16,
                                   DType::kFloat8E5M2,
                                   DType::kFloat8E4M3};

// Two NVTEShapes are equal iff they have the same rank and identical dims.
bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
  if (s1.ndim != s2.ndim) {
    return false;
  }
  return std::equal(s1.data, s1.data + s1.ndim, s2.data);
}

50
// Returns the size in bits of a single element of the given DType.
// Dispatches through the TYPE_SWITCH macro; TypeInfo<T>::size is the per-type
// bit width (callers such as bytes() multiply by element count and divide by 8).
size_t typeToNumBits(DType type) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
  {
      return TypeInfo<T>::size;
  });
}

// Human-readable name of a DType, used for test naming/diagnostics.
// Throws std::out_of_range (via unordered_map::at) for unmapped types.
const std::string &typeName(DType type) {
  static const std::unordered_map<DType, std::string> name_map = {
    {DType::kByte, "byte"},
    {DType::kInt32, "int32"},
    {DType::kInt64, "int64"},
    {DType::kFloat32, "float32"},
    {DType::kFloat16, "float16"},
    {DType::kBFloat16, "bfloat16"},
    {DType::kFloat8E4M3, "float8e4m3"},
    {DType::kFloat8E5M2, "float8e5m2"},
    {DType::kFloat8E8M0, "float8e8m0"}
    // FP4 entry is only available with CUDA >= 12.8.
    #if FP4_TYPE_SUPPORTED
    ,
    {DType::kFloat4E2M1, "float4e2m1"}
    #endif
  };
  return name_map.at(type);
}

76
77
78
79
80
81
82
83
84
85
86
// Human-readable label for each input-fill strategy.
// Throws std::out_of_range (via unordered_map::at) for unmapped cases.
const std::string& caseName(InputsFillCase type) {
  static const std::unordered_map<InputsFillCase, std::string> case_names = {
      {InputsFillCase::uniform, "uniform"},
      {InputsFillCase::zeros, "zeros"},
      {InputsFillCase::zero_to_minNorm, "zero_to_minNorm"},
      {InputsFillCase::minNorm_to_maxNorm, "minNorm_to_maxNorm"},
      {InputsFillCase::maxNorm_to_inf, "maxNorm_to_inf"},
  };
  return case_names.at(type);
}

// Product of the shape's dimensions over the half-open range [begin, end).
size_t product(const NVTEShape &shape, size_t begin, size_t end) {
    NVTE_CHECK(end <= shape.ndim);
    size_t result = 1;
    for (size_t dim = begin; dim < end; ++dim) {
      result *= shape.data[dim];
    }
    return result;
}
94

95
96
97
// Total number of elements described by the shape.
size_t product(const NVTEShape &shape) {
  return product(shape, 0, shape.ndim);
}
98

99
100
101
102
103
104
105
106
// Product of the shape's dimensions over the half-open range [begin, end).
// Fix: take the vector by const reference instead of by value, avoiding a
// needless copy on every call (source-compatible for all callers).
size_t product(const std::vector<size_t>& shape, size_t begin, size_t end) {
    size_t ret = 1;
    NVTE_CHECK(end <= shape.size());
    for (size_t i = begin; i < end; ++i) {
      ret *= shape[i];
    }
    return ret;
}
Przemek Tredak's avatar
Przemek Tredak committed
107

108
109
110
111
112
113
114
115
// Total number of elements described by the shape.
size_t product(const std::vector<size_t>& shape) {
  return product(shape, 0, shape.size());
}

// Integer ceiling division: smallest q with q * y >= x (y must be nonzero).
size_t DIVUP(const size_t &x, const size_t &y){
  const size_t quotient = x / y;
  return (x % y == 0) ? quotient : quotient + 1;
}

116
117
118
119
// Rounds x up to the nearest multiple of y (y must be nonzero).
size_t DIVUP_TO_MULTIPLE(const size_t &x, const size_t &y){
  const size_t num_blocks = (x + y - 1) / y;
  return num_blocks * y;
}

120
121
122
// Metadata describing a scale-inverse tensor: its shape, element dtype, and
// per-element size in bits.
struct scale_inv_meta {
  std::vector<size_t> shape;
  DType type;
  size_t type_size_bits;
  // Total size in bytes (integer division: assumes the total bit count is a
  // multiple of 8).
  size_t bytes() const noexcept {
    return (product(shape) * type_size_bits) / 8;
  }
};

129
130
131
132
// Size in bytes of a tensor with the given shape and dtype (integer division:
// assumes the total bit count is a multiple of 8, relevant for sub-byte types).
size_t bytes(const NVTEShape& shape, const DType type) {
  return (product(shape) * typeToNumBits(type)) / 8;
}

133
134
// Converts a std::vector shape into an NVTEShape.
NVTEShape convertShape(const std::vector<size_t>& s) {
  return nvte_make_shape(s.data(), s.size());
}

// Computes the shapes/dtypes of the rowwise and columnwise scale-inverse
// tensors for a data tensor of the given shape under the given scaling mode.
// Returns {rowwise_meta, columnwise_meta}; errors out on unknown modes.
// Fix: removed a second, byte-identical (and therefore unreachable) copy of
// the NVTE_MXFP8_1D_SCALING branch that followed the NVFP4 branch.
std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
                                                     const NVTEScalingMode scaling_mode) {
  // Copies the NVTEShape dims into a vector for first/last_dimension().
  auto to_shape_vec = [](const NVTEShape& s) {
    std::vector<size_t> v;
    v.reserve(s.ndim);
    for (size_t i = 0; i < s.ndim; ++i) {
      v.push_back(s.data[i]);
    }
    return v;
  };

  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
    // A single FP32 scale-inverse, identical for both directions.
    scale_inv_meta ret;
    ret.shape = {1};
    ret.type = DType::kFloat32;
    ret.type_size_bits = typeToNumBits(DType::kFloat32);
    return {ret, ret};
  }
  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
    std::vector<size_t> shape_vec = to_shape_vec(shape);
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    scale_inv_meta ret_rowwise, ret_colwise;

    // One E8M0 scale per 32 elements along the scaled dimension; the scale
    // tensor dims are padded up to the alignment constants.
    const size_t block_size_X_rowwise = 32;
    size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
    ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};

    const size_t block_size_Y_colwise = 32;
    size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
    size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
    ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};

    ret_rowwise.type = DType::kFloat8E8M0;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
    ret_colwise.type = DType::kFloat8E8M0;
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);

    return {ret_rowwise, ret_colwise};
  }
  if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
    std::vector<size_t> shape_vec = to_shape_vec(shape);
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    NVTE_CHECK(last_dim % 32 == 0);
    NVTE_CHECK(first_dim % 32 == 0);

    scale_inv_meta ret_rowwise, ret_colwise;

    // One E4M3 scale per 16 elements; the columnwise ("transposed") layout
    // swaps the roles of the two dimensions.
    size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise);
    ret_rowwise.shape = {scale_dim_Y, scale_dim_X};

    size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise);
    size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise);
    ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t};

    ret_rowwise.type = DType::kFloat8E4M3;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
    ret_colwise.type = DType::kFloat8E4M3;
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);

    return {ret_rowwise, ret_colwise};
  }
  if (scaling_mode == NVTE_BLOCK_SCALING_2D) {
    std::vector<size_t> shape_vec = to_shape_vec(shape);
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);

    scale_inv_meta ret_rowwise, ret_colwise;

    {
      // One FP32 scale per (block_len x block_len) tile; inner dim padded to 4.
      auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len()));
      auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len())), 4) * 4;
      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
    }
    {
      auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len()));
      auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len())), 4) * 4;
      ret_colwise.shape = {scale_dim_0, scale_dim_1};
    }
    ret_rowwise.type = DType::kFloat32;
    ret_colwise.type = DType::kFloat32;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);

    return {ret_rowwise, ret_colwise};
  }
  if (scaling_mode == NVTE_BLOCK_SCALING_1D) {
    std::vector<size_t> shape_vec = to_shape_vec(shape);
    size_t first_dim = first_dimension(shape_vec);
    size_t last_dim = last_dimension(shape_vec);
    scale_inv_meta ret_rowwise, ret_colwise;

    {
      // One FP32 scale per 1D block of block_len elements; other dim padded to 4.
      auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len()));
      auto scale_dim_1 = DIVUP(first_dim, 4) * 4;
      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
    }
    {
      auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len()));
      auto scale_dim_1 = DIVUP(last_dim, 4) * 4;
      ret_colwise.shape = {scale_dim_0, scale_dim_1};
    }
    ret_rowwise.type = DType::kFloat32;
    ret_colwise.type = DType::kFloat32;
    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);

    return {ret_rowwise, ret_colwise};
  }

  NVTE_ERROR("Invalid scaling mode!");
}

// Constructs a test tensor: allocates zero-initialized GPU buffers and
// matching host mirrors for the requested usages (rowwise/columnwise), plus
// amax/scale/scale-inverse buffers as required by the scaling mode.
// The RNG is seeded deterministically from the test + tensor name.
Tensor::Tensor(const std::string& name,
               const NVTEShape &shape, const DType type,
               const bool rowwise, const bool columnwise,
               const NVTEScalingMode &scaling_mode) {
  name_ = name;
  const size_t seed = create_seed_from_tensor_name(name);
  gen_.seed(seed);
  rowwise_ = rowwise;
  columnwise_ = columnwise;
  size_t total_size = bytes(shape, type);
  void *dptr_rowwise = nullptr;
  void *dptr_columnwise = nullptr;
  cpu_data_rowwise_ = nullptr;
  cpu_data_columnwise_ = nullptr;
  amax_cpu_data_ = nullptr;
  scale_cpu_data_ = nullptr;
  rowwise_scale_inv_cpu_data_ = nullptr;
  columnwise_scale_inv_cpu_data_ = nullptr;
  float *amax = nullptr, *scale = nullptr;
  float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr;
  // Columnwise usage requires at least a 2D tensor.
  if (columnwise) {
    NVTE_CHECK(shape.ndim >= 2);
  }
  // Collapse leading dims so scale computations see a 2D (rows, cols) view.
  std::vector<size_t> normalized_shape_v = {product(shape, 0, shape.ndim - 1),
                                            shape.data[shape.ndim - 1]};
  NVTEShape normalized_shape = convertShape(normalized_shape_v);
  NVTEShape columnwise_shape = {};

  std::vector<size_t> columnwise_shape_vec;
  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING
      || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
    // Transpose when tensor scaling
    columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
    for (size_t i = 0; i < shape.ndim - 1; ++i) {
      columnwise_shape_vec.emplace_back(shape.data[i]);
    }
  } else {
    // Same shape for MX and NVFP4
    for (size_t i = 0; i < shape.ndim; ++i) {
      columnwise_shape_vec.emplace_back(shape.data[i]);
    }
  }

  if (columnwise) {
    columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(), columnwise_shape_vec.size());
  }

  tensor_ = TensorWrapper(scaling_mode);

  // Allocate and zero both device and host copies of the requested data
  // buffers.  NOTE(review): cudaMalloc/cudaMemset return codes are not
  // checked anywhere in this constructor.
  if (total_size != 0) {
    if (rowwise) {
      cudaMalloc((void**)&dptr_rowwise, total_size);  // NOLINT(*)
      cudaMemset(dptr_rowwise, 0, total_size);
      cpu_data_rowwise_ = std::make_unique<unsigned char[]>(total_size);
      std::fill_n(cpu_data_rowwise_.get(), total_size, 0);
    }
    if (columnwise) {
      cudaMalloc((void**)&dptr_columnwise, total_size);  // NOLINT(*)
      cudaMemset(dptr_columnwise, 0, total_size);
      cpu_data_columnwise_ = std::make_unique<unsigned char[]>(total_size);
      std::fill_n(cpu_data_columnwise_.get(), total_size, 0);
    }
  }

  // NVFP4 data is stored as FP4 regardless of the requested dtype
  // (only when the toolkit supports the FP4 type).
#if FP4_TYPE_SUPPORTED
  const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
  const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
  tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape);
  tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape);
#else
  tensor_.set_rowwise_data(dptr_rowwise, type, shape);
  tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape);
#endif

  if (isFp8Type(type) || isFp4Type(type)) {
    if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
      // Delayed scaling: scalar amax, scale and scale-inverse.
      cudaMalloc((void**)&amax, sizeof(float));  // NOLINT(*)
      cudaMemset(amax, 0, sizeof(float));
      cudaMalloc((void**)&scale, sizeof(float));  // NOLINT(*)
      cudaMemset(scale, 0, sizeof(float));
      amax_cpu_data_ = std::make_shared<float>(0);
      scale_cpu_data_ = std::make_shared<float>(0);
      tensor_.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
      tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      cudaMalloc((void**)&rowwise_scale_inv, sizeof(float));  // NOLINT(*)
      if (rowwise) {
        tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
                                      std::vector<size_t>{1});
        rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
        std::fill_n(rowwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
      }
      if (columnwise) {
        // NOTE(review): the columnwise view deliberately(?) shares the
        // rowwise_scale_inv device buffer -- a single scale for both
        // directions under delayed scaling; confirm against TensorWrapper.
        tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
                                          std::vector<size_t>{1});
        columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
        std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
      }
    } else {
      if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
        // Used for NVFP4 second stage scaling
        cudaMalloc((void**)&scale, sizeof(float));  // NOLINT(*)
        cudaMemset(scale, 0, sizeof(float));
        scale_cpu_data_ = std::make_shared<float>(0);
        tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      }
      // Block-scaled modes: per-block scale-inverse tensors sized by get_scales.
      auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode());
      auto rowwise_scale_size = rowwise_scale_meta.bytes();
      auto columnwise_scale_size = colwise_scale_meta.bytes();
      auto scale_shape = rowwise_scale_meta.shape;
      auto columnwise_scale_shape = colwise_scale_meta.shape;
      if (rowwise) {
        cudaMalloc((void **)&rowwise_scale_inv, rowwise_scale_size);  // NOLINT(*)
        cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size);
        rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(rowwise_scale_size);
        std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0);
        auto scale_dtype = rowwise_scale_meta.type;
        tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape);
      }
      if (columnwise) {
        cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size);  // NOLINT(*)
        cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size);
        columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(columnwise_scale_size);
        std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0);
        auto scale_dtype = colwise_scale_meta.type;
        tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape);
      }
    }
  }
}

// Copies all device-side buffers (data, amax/scale for delayed scaling, and
// scale-inverse tensors) into the host mirrors.  Uses blocking cudaMemcpy,
// so no explicit synchronization is needed afterwards.
void Tensor::to_cpu() const {
  const NVTEShape s = tensor_.shape();
  const size_t size = bytes(s, tensor_.dtype());
  if (rowwise_) {
    cudaMemcpy(cpu_data_rowwise_.get(),
               tensor_.get_rowwise_data().data_ptr,
               size,
               cudaMemcpyDeviceToHost);
  }
  if (columnwise_) {
    const DType colwise_type = tensor_.dtype();

    // Same dtype/element count as rowwise, so the byte size matches.
    const size_t colwise_size = bytes(s, colwise_type);
    cudaMemcpy(cpu_data_columnwise_.get(),
                tensor_.get_columnwise_data().data_ptr,
                colwise_size,
                cudaMemcpyDeviceToHost);
  }
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
    // NOTE(review): only delayed scaling downloads amax/scale here, while
    // from_cpu() also uploads scale for NVFP4 -- confirm the asymmetry is
    // intentional.
    if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) {
      if (tensor_.amax() != nullptr){
        cudaMemcpy(amax_cpu_data_.get(),
                  tensor_.amax(),
                  sizeof(float),
                  cudaMemcpyDeviceToHost);
      }
      cudaMemcpy(scale_cpu_data_.get(),
                 tensor_.scale(),
                 sizeof(float),
                 cudaMemcpyDeviceToHost);
    }
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
    if (rowwise_) {
      auto scale_size = rowwise_scale_meta.bytes();
      cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
                 tensor_.get_rowwise_scale_inv().data_ptr,
                 scale_size,
                 cudaMemcpyDeviceToHost);
    }
    if (columnwise_) {
      auto scale_size = colwise_scale_meta.bytes();
      cudaMemcpy(columnwise_scale_inv_cpu_data_.get(),
                 tensor_.get_columnwise_scale_inv().data_ptr,
                 scale_size,
                 cudaMemcpyDeviceToHost);
    }
  }
}

// Uploads all host-side mirrors (data, amax/scale where present, and
// scale-inverse tensors) to the device buffers.  Uses blocking cudaMemcpy.
void Tensor::from_cpu() const {
  const NVTEShape s = tensor_.shape();
  const size_t size = bytes(s, tensor_.dtype());
  if (rowwise_) {
    cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size,
               cudaMemcpyHostToDevice);
  }
  if (columnwise_) {
    cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
               cudaMemcpyHostToDevice);
  }
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
    // Delayed scaling and NVFP4 both carry a scalar scale; amax is only
    // uploaded when the tensor actually has one (NVFP4 has no amax buffer).
    if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
        || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) {
      if (tensor_.amax() != nullptr){
        cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
      }
      cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
    }
    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
    if (rowwise_) {
      auto scale_size = rowwise_scale_meta.bytes();
      cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
                 rowwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
    }
    if (columnwise_) {
      auto scale_size = colwise_scale_meta.bytes();
      cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr,
                 columnwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
    }
  }
}

// Sets the host-side scalar scale and uploads it to the device.
// No-op for non-quantized dtypes and for scaling modes other than delayed
// tensor scaling (only that mode consumes a user-set scale here).
void Tensor::set_scale(float scale) {
  const bool quantized = isFp8Type(dtype()) || isFp4Type(dtype());
  if (!quantized) {
    return;
  }
  NVTE_CHECK(scale_cpu_data_);
  if (tensor_.scaling_mode() != NVTE_DELAYED_TENSOR_SCALING) {
    return;
  }
  *scale_cpu_data_ = scale;
  from_cpu();
}

// Fills the scale-inverse tensor(s) on the host and uploads them to the
// device.  Single-scale modes store the given value directly; block-scaled
// modes fill each scale byte with a random value in [0, 127] drawn from the
// tensor's deterministic RNG.
void Tensor::set_scale_inv(float scale_inv) {
  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
    if (rowwise_) {
      NVTE_CHECK(rowwise_scale_inv_cpu_data_);
    }
    if (columnwise_) {
      NVTE_CHECK(columnwise_scale_inv_cpu_data_);
    }

    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode());
    // Fix: std::uniform_int_distribution<uint8_t> is undefined behavior (the
    // standard restricts IntType to short/int/long/long long and their
    // unsigned variants); draw ints and narrow to uint8_t instead.
    std::uniform_int_distribution<int> dis(0, 127);
    if (rowwise_) {
      auto num_scales = product(rowwise_scale_meta.shape);
      if (num_scales == 1) {
        rowwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
      } else {
        auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr<uint8_t>();
        for (size_t i = 0; i < num_scales; i++) {
          scale_inv_ptr[i] = static_cast<uint8_t>(dis(gen_));
        }
      }
    }
    if (columnwise_) {
      auto num_scales = product(colwise_scale_meta.shape);
      if (num_scales == 1) {
        columnwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
      } else {
        auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr<uint8_t>();
        for (size_t i = 0; i < num_scales; i++) {
          scale_inv_ptr[i] = static_cast<uint8_t>(dis(gen_));
        }
      }
    }
    from_cpu();
  }
}

// Rebuilds this tensor's wrapper so it keeps its own data buffers but shares
// the quantization metadata (amax, scale, scale-inverse) of `other`.
// No-op unless both tensors are FP8, or both are FP4.
void Tensor::shareFP8Meta(const Tensor &other) {
  if ((isFp8Type(dtype()) && isFp8Type(other.dtype()))
      || isFp4Type(dtype()) && isFp4Type(other.dtype())) {
    // Adopt other's scaling mode; keep this tensor's data pointers/shapes.
    auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
    auto my_rowwise_data = tensor_.get_rowwise_data();
    new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
                                my_rowwise_data.shape);
    auto my_columnwise_data = tensor_.get_columnwise_data();
    new_tensor.set_columnwise_data(my_columnwise_data.data_ptr,
                                   static_cast<DType>(my_columnwise_data.dtype),
                                   my_columnwise_data.shape);
    // Point all metadata at other's device buffers (shared, not copied).
    auto other_amax = other.tensor_.get_amax();
    new_tensor.set_amax(other_amax.data_ptr, static_cast<DType>(other_amax.dtype),
                        other_amax.shape);
    auto other_scale = other.tensor_.get_scale();
    new_tensor.set_scale(other_scale.data_ptr, static_cast<DType>(other_scale.dtype),
                         other_scale.shape);
    auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv();
    new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr,
                                     static_cast<DType>(other_row_scale_inv.dtype),
                                     other_row_scale_inv.shape);
    auto other_col_scale_inv = other.tensor_.get_columnwise_scale_inv();
    new_tensor.set_columnwise_scale_inv(other_col_scale_inv.data_ptr,
                                        static_cast<DType>(other_col_scale_inv.dtype),
                                        other_col_scale_inv.shape);
    tensor_ = std::move(new_tensor);
    // Refresh host mirrors from the (now shared) device metadata.
    to_cpu();
  }
}

using std::to_string;

// Formats a vector as "[a, b, c]"; an empty vector yields "[]".
// Fix: the original unconditionally popped two characters, which for an empty
// vector pops '[' and then calls std::string::pop_back on an empty string
// (undefined behavior).
template <typename T>
std::string to_string(const std::vector<T> &v) {
  std::string s = "[";
  for (const auto x : v) {
    s += to_string(x) + ", ";
  }
  if (!v.empty()) {
    // Drop the trailing ", ".
    s.pop_back();
    s.pop_back();
  }
  return s + "]";
}

// Converts a flat (row-major) index into per-dimension coordinates for the
// given shape.  Assumes shape.ndim >= 1.
std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
  std::vector<size_t> coords(shape.ndim);
  size_t remainder = i;
  // Peel off the fastest-varying dimensions first, filling back-to-front.
  for (size_t dim = shape.ndim - 1; dim > 0; --dim) {
    coords[dim] = remainder % shape.data[dim];
    remainder /= shape.data[dim];
  }
  coords[0] = remainder;
  return coords;
}

601
602
// Element-wise comparison of a test tensor against a host reference buffer,
// single-threaded.  A value counts as mismatched when it fails both the
// absolute (atol) and relative (rtol) checks; for non-FP32 dtypes a mismatch
// is forgiven when it can be explained by round-to-nearest going to the other
// side of the midpoint.  Fails the current gtest once the number of real
// mismatches exceeds tolerable_mismatches_limit.
void compareResults_sequential(const std::string &name, const Tensor &test,
                               const void *ref, const bool rowwise,
                               double atol, double rtol, bool if_on_gpus,
                               const size_t tolerable_mismatches_limit) {
  if (if_on_gpus) test.to_cpu();
  const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
  const size_t N = product(shape);
  size_t mismatches_num = 0;
  // NOTE(review): int index narrows for tensors with more than INT_MAX
  // elements -- confirm N stays within range for all tests.
  int first_mismatch_idx = -1;
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
    const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
    const T *ref_data = reinterpret_cast<const T*>(ref);
    for (size_t i = 0; i < N; ++i) {
// On HIP, convert through float first (no direct T -> double conversion path).
#ifndef __HIP_PLATFORM_AMD__
      double t = static_cast<double>(test_data[i]);
      double r = static_cast<double>(ref_data[i]);
#else
      double t = static_cast<double>(static_cast<float>(test_data[i]));
      double r = static_cast<double>(static_cast<float>(ref_data[i]));
#endif
      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
      /* For Float32 the floating point comparison is enough to error out */
      bool assertion = mismatch && test.dtype() == DType::kFloat32;
      if (mismatch && !assertion) {
        /* Check if it is just a failure of round to nearest choosing different
           side of the real value */
        const double mean = (t + r) / 2;
        const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
        const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
#ifndef __HIP_PLATFORM_AMD__
        const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
        const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
#else
        const double cast_mean_p = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_p))));
        const double cast_mean_m = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_m))));
#endif

#ifdef __HIP_PLATFORM_AMD__
        assertion = !(cast_mean_m == std::min<double>(t,r) && cast_mean_p == std::max<double>(t,r));
#else
        assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
#endif
      }
      std::string direction = rowwise ? "rowwise" : "columnwise";
      if (assertion) {
        mismatches_num++;
        if (first_mismatch_idx == -1) {
          first_mismatch_idx = i;
        }
      }
      if (mismatches_num > tolerable_mismatches_limit) {
        const double first_mismatch_t = static_cast<double>(test_data[first_mismatch_idx]);
        const double first_mismatch_r = static_cast<double>(ref_data[first_mismatch_idx]);

        GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of "
                    << tolerable_mismatches_limit << "." << std::endl
                    << "Error in tensor " << name << " in "
                    << direction << " direction." << std::endl
                     << "First mismatch at place " << to_string(unravel(first_mismatch_idx, shape))
                     << " (" << std::to_string(first_mismatch_idx) << "): "
                     << first_mismatch_t << " vs " << first_mismatch_r;
      }
    }
  );
}

// OpenMP-parallel scan for mismatching elements between test and reference
// buffers.  Returns the index of the first (lowest) mismatch, or N when none
// is found, and accumulates the total mismatch count into `mismatches`.
// Uses the same atol/rtol + round-to-nearest-forgiveness rule as the
// sequential comparator.
template <typename T>
static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data,
                                  const size_t N, const double atol, const double rtol,
                                  size_t& mismatches) {
  // NOTE(review): int narrows for N > INT_MAX elements -- confirm tensor
  // sizes stay within range.  Each thread keeps a private copy; the min
  // reduction combines them at the end of the parallel region.
  int first_mismatch_idx = N;

  #pragma omp parallel reduction(min: first_mismatch_idx) reduction(+: mismatches) proc_bind(spread)
  {
    size_t thread_mismatches = 0;
    #pragma omp for schedule(static)
    for (size_t i = 0; i < N; ++i) {
// On HIP, convert through float first (no direct T -> double conversion path).
#ifndef __HIP_PLATFORM_AMD__
    double t = static_cast<double>(test_data[i]);
    double r = static_cast<double>(ref_data[i]);
#else
    double t = static_cast<double>(static_cast<float>(test_data[i]));
    double r = static_cast<double>(static_cast<float>(ref_data[i]));
#endif

      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
      /* For Float32 the floating point comparison is enough to error out */
      bool assertion = mismatch && (data_type == DType::kFloat32);
      if (mismatch && !assertion) {
        /* Check if it is just a failure of round to nearest choosing different
            side of the real value */
        const double mean = (t + r) / 2;
        const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
        const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
#ifndef __HIP_PLATFORM_AMD__
      const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
      const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
#else
      const double cast_mean_p = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_p))));
      const double cast_mean_m = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_m))));
#endif
#ifdef __HIP_PLATFORM_AMD__
      assertion = !(cast_mean_m == std::min<double>(t,r) && cast_mean_p == std::max<double>(t,r));
#else
      assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
#endif
      }
      if (assertion) {
        if (i < first_mismatch_idx) {
          first_mismatch_idx = i;
        }
        thread_mismatches++;
      }
    }
    mismatches += thread_mismatches;
  }
  return first_mismatch_idx;
}

// Compares a Tensor against a reference buffer in parallel and fails the
// current gtest when the mismatch count exceeds tolerable_mismatches_limit.
// `rowwise` selects which of the tensor's two data layouts is compared;
// `if_on_gpus` first copies the tensor contents back to the host.
void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref,
                             const bool rowwise, double atol, double rtol, bool if_on_gpus,
                             const size_t tolerable_mismatches_limit) {
  if (if_on_gpus) test.to_cpu();
  const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape();
  const size_t N = product(shape);
  size_t mismatches = 0;
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
    const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
    const T *ref_data = reinterpret_cast<const T*>(ref);
    // i == N means no mismatch was found; otherwise i is the first bad index.
    const size_t i = getFirstMismatchIdx<T>(test.dtype(), test_data, ref_data, N, atol, rtol, mismatches);
    if ((i != N) && (mismatches > tolerable_mismatches_limit)) {
#ifndef __HIP_PLATFORM_AMD__
      const double t = static_cast<double>(test_data[i]);
      const double r = static_cast<double>(ref_data[i]);
#else
      // HIP: convert through float, which is always defined for narrow types.
      const double t = static_cast<double>(static_cast<float>(test_data[i]));
      const double r = static_cast<double>(static_cast<float>(ref_data[i]));
#endif
      std::string direction = rowwise ? "rowwise" : "columnwise";
      GTEST_FAIL() << mismatches << " mismatche(s) which is more than tolerable mismatch limit of "
                   << tolerable_mismatches_limit << "." << std::endl
                   << "Error in tensor " << name << " in "
                   << direction << " direction." << std::endl
                   << "Mismatch at place " << to_string(unravel(i, shape))
                   << " (" << std::to_string(i) << "): " << t << " vs " << r;
    }
  );
}

// Tensor-vs-reference comparison entry point. Dispatches to the parallel
// implementation; flip the constant below to debug with the sequential one.
void compareResults(const std::string &name, const Tensor &test, const void *ref,
                    const bool rowwise, double atol, double rtol, bool if_on_gpus,
                    const size_t tolerable_mismatches_limit) {
  constexpr bool use_sequential_impl = false;
  if constexpr (use_sequential_impl) {
    compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus,
                              tolerable_mismatches_limit);
  } else {
    compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus,
                            tolerable_mismatches_limit);
  }
}

// Scalar comparison helper: asserts |test - ref| is within atol, or within
// rtol relative error when ref is nonzero.
void compareResults(const std::string &name, const float test, const float ref,
                    double atol, double rtol) {
#ifndef __HIP_PLATFORM_AMD__
  const double observed = static_cast<double>(test);
  const double expected = static_cast<double>(ref);
#else
  const double observed = static_cast<double>(static_cast<float>(test));
  const double expected = static_cast<double>(static_cast<float>(ref));
#endif
  const bool within_atol = fabs(observed - expected) <= atol;
  const bool within_rtol =
      (expected != 0) && fabs((observed - expected) / expected) <= rtol;
  const bool mismatched = !(within_atol || within_rtol);
  ASSERT_FALSE(mismatched) << "Error in " << name << std::endl
                           << "Mismatch: " << observed << " vs " << expected;
}

// Byte-exact comparison of two uint8 buffers of length N, tolerating up to
// ceil(N * mismatch_rate_tol) mismatching bytes before failing the test.
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
                    size_t N, float mismatch_rate_tol) {
  size_t max_mismatches = std::ceil(N * mismatch_rate_tol);
  size_t n_mismatches = 0;
  std::vector<size_t> mismatch_indices;
  // Fix: loop index was `int` (overflow for N > INT_MAX) and the error report
  // printed test[i]/ref[i] — the values at the *current* position — instead of
  // the values at each recorded mismatch index.
  for (size_t i = 0; i < N; i++){
    bool mismatch = test[i] != ref[i];
    if (mismatch){
      n_mismatches++;
      mismatch_indices.push_back(i);
    }
    if (n_mismatches > max_mismatches){
      std::cout << "Error in " << name << std::endl;
      for (auto &index : mismatch_indices)
        std::cout << "Mismatch at (" << index << "):" << static_cast<int>(test[index]) << " vs "
        << static_cast<int>(ref[index]) << std::endl;
      GTEST_FAIL() << n_mismatches << " mismatche(s) which is more than mismatch tol.";
    }
  }
}

// Maps a scale-factor storage type to a wider type that streams print
// readably: uint8_t would otherwise print as a character, and fp8e4m3 has no
// stream operator of its own.
template <typename T>
struct CastToType;

template <>
struct CastToType<uint8_t> {
  using type = int;
};

template <>
struct CastToType<fp8e4m3> {
  using type = float;
};

// Compares a row_blocks x col_blocks grid of scaling factors (row pitch
// `stride`) and fails the current gtest once the mismatch count exceeds
// min(abs_tolerable_mismatches_limit, floor(N * rel_tolerable_mismatches_limit)).
// T is uint8_t (raw byte scales compared with an integer atol — presumably
// E8M0 biased exponents; confirm against callers) or fp8e4m3 (decoded to
// float and compared with the FP8 tolerances from getTolerances).
// mismatches_num is an out-parameter receiving the total mismatch count.
template <typename T>
void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
                             const size_t row_blocks, const size_t col_blocks, const size_t stride,
                             size_t& mismatches_num, const size_t atol,
                             const double abs_tolerable_mismatches_limit,
                             const double rel_tolerable_mismatches_limit)
{
  using UpcastType = typename CastToType<T>::type;  // printable type for error output
  auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3);

  const size_t N = row_blocks * col_blocks;
  const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit,
                                                     std::floor(N * rel_tolerable_mismatches_limit));
  mismatches_num = 0;
  std::vector<int> mismatch_indices;

  for (int i = 0; i < row_blocks; ++i) {
    for (int j = 0; j < col_blocks; ++j) {
      const int idx = i * stride + j;
      float t, r;
      bool assertion = false;
      if (std::is_same<T, uint8_t>::value) {
        // Byte path: straight integer distance between the raw bytes.
        t = static_cast<float>(test[idx]);
        r = static_cast<float>(ref[idx]);
        assertion = std::abs(t - r) > atol;
      } else {
        // FP8 path: decode to float and apply FP8 tolerances.
        t = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&test[idx]));
        r = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&ref[idx]));
        const bool mismatch = (fabs(t - r) > atol_fp8e4m3)
                              && (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3);
        if (mismatch) {
          /* Check if it is just a failure of round to nearest choosing different
            side of the real value */
          const double mean = (t + r) / 2;
          const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
          const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
          const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
          const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
          // Accept if rounding the perturbed means reproduces exactly the two
          // observed values (rounding just picked different sides).
          assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
        }
      }
      if (assertion) {
        mismatches_num++;
        mismatch_indices.push_back(idx);
      }
      if (mismatches_num > tolerable_mismatches_limit) {
        std::cout << "Error in " << name << std::endl;
        for (const int index : mismatch_indices) {
          std::cout << "Mismatch at (" << index << "):"
                    << static_cast<UpcastType>(test[index]) << " vs "
                    << static_cast<UpcastType>(ref[index]) << std::endl;
        }
        GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of "
                     << tolerable_mismatches_limit << ".";
      }
    }
  }
}

// Instantiate templates
// The two scale-factor encodings exercised by the tests: raw bytes (uint8_t)
// and fp8e4m3 values.
template
void compare_scaling_factors<uint8_t>(const std::string &name, const uint8_t *test, const uint8_t *ref,
                                      const size_t row_blocks, const size_t col_blocks, const size_t stride,
                                      size_t& mismatches_num, const size_t atol,
                                      const double abs_tolerable_mismatches_limit,
                                      const double rel_tolerable_mismatches_limit);

template
void compare_scaling_factors<fp8e4m3>(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref,
                                      const size_t row_blocks, const size_t col_blocks, const size_t stride,
                                      size_t& mismatches_num, const size_t atol,
                                      const double abs_tolerable_mismatches_limit,
                                      const double rel_tolerable_mismatches_limit);


// Returns the {atol, rtol} pair the comparison helpers use for a given dtype.
// Raises (via NVTE_ERROR) for dtypes without defined tolerances.
std::pair<double, double> getTolerances(const DType type) {
  switch(type) {
    case DType::kFloat32:
      return {1e-6, 5e-6};
    case DType::kFloat16:
      return {1e-5, 1e-3};
    case DType::kBFloat16:
      return {1e-5, 1e-2};
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
    case DType::kFloat8E8M0:
      return {1e-2, 1e-2};
    default:
      // Fix: was NVTE_CHECK("Invalid type!"), which passes the message string
      // as the condition — a non-null pointer, so the check always succeeded
      // and unknown dtypes silently fell through to {0, 0}.
      NVTE_ERROR("Invalid type!");
  }
  return {0, 0};
}

// Fills `data` with `size` values drawn uniformly from [-2, 1), splitting the
// work across OpenMP threads. Each thread copies the generator, skips ahead to
// its chunk with discard(), and the shared generator is advanced at the end so
// subsequent draws continue the sequence deterministically.
// Assumes each distribution draw consumes a fixed number of engine calls
// (true for mt19937 + uniform_real_distribution in practice); that count is
// measured empirically below.
template <typename T>
void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
  // Check how many RNG calls are required to generate one uniform random value
  int rng_calls_per_val = 0;
  {
    std::mt19937 gen1 = *gen, gen2 = *gen;
    std::uniform_real_distribution<> dis(-2.0, 1.0);
    const float _ = dis(gen1);
    (void)_;
    while (gen2 != gen1) {
      auto drawn = gen2();
      (void)drawn;
      ++rng_calls_per_val;
    }
  }

  // Generate uniform random values in parallel
  #pragma omp parallel proc_bind(spread)
  {
    std::mt19937 gen_local = *gen;
    const int thread_ID = omp_get_thread_num();
    // Fix: was omp_get_max_threads(), which inside a parallel region is the
    // limit for a *new* team, not the current team size. If the actual team
    // is smaller, the tail of the buffer was never written. Also use size_t
    // chunk arithmetic so large buffers do not overflow int.
    const int threads_num = omp_get_num_threads();
    const size_t chunk_size = (size + threads_num - 1) / threads_num;
    const size_t idx_min = chunk_size * thread_ID;
    const size_t idx_max = std::min(chunk_size * (thread_ID + 1), size);
    gen_local.discard(idx_min * rng_calls_per_val);
    std::uniform_real_distribution<> dis(-2.0, 1.0);

    for (size_t i = idx_min; i < idx_max; ++i) {
#ifndef __HIP_PLATFORM_AMD__
      data[i] = static_cast<T>(dis(gen_local));
#else
      // HIP: narrow types construct from float, not double.
      data[i] = static_cast<T>(static_cast<float>(dis(gen_local)));
#endif
    }
  }
  // Advance the shared generator past everything the workers consumed.
  gen->discard(size * rng_calls_per_val);
}

// Fills the tensor's host buffer with uniform random values in [-2, 1) —
// rowwise data when present, otherwise columnwise — then sets a random
// scale-inv drawn from the same range and copies everything to the device.
void fillUniform(Tensor *t) {
  if (t->rowwise()) {
    const size_t size = product(t->rowwise_shape());
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T,
      {
        T *data = t->rowwise_cpu_dptr<T>();
        generate_data_uniformly(data, size, &(t->gen()));
      }
    );
  } else {
    const size_t size = product(t->columnwise_shape());
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T,
      {
        T *data = t->columnwise_cpu_dptr<T>();
        generate_data_uniformly(data, size, &(t->gen()));
      }
    );
  }
  std::uniform_real_distribution<> dis(-2.0, 1.0);
  t->set_scale_inv(dis(t->gen()));
  t->from_cpu();
}

// Fills the tensor's rowwise host data for one compile-time fill case:
//   zeros            — everything set to zero;
//   uniform          — magnitudes drawn from [-2, 1);
//   other cases      — magnitudes drawn from the [ranges[Case], ranges[Case+1])
//                      band of the InputEncoding's quantized limits.
// A separate +/- draw decides each element's sign. Scale-inv is reset to 1
// and the data is pushed to the device.
template<typename InputEncoding, InputsFillCase Case>
void fillCase_special(Tensor *t) {
  const size_t size = product(t->rowwise_shape());

  if constexpr (Case == InputsFillCase::zeros) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
      InputType *data = t->rowwise_cpu_dptr<InputType>();
      for (size_t i = 0; i < size; ++i) {
        data[i] = static_cast<InputType>(0);
      }
    });
  } else {
    double minAbs = -2.0;
    double maxAbs = 1.0;
    if constexpr (Case != InputsFillCase::uniform) {
      // Band of magnitudes specific to this fill case (e.g. subnormal range).
      minAbs = Quantized_Limits<InputEncoding>::ranges[Case];
      maxAbs = Quantized_Limits<InputEncoding>::ranges[Case + 1];
    }
    std::uniform_real_distribution<> dis(minAbs, maxAbs);
    std::uniform_real_distribution<> dis_sign(-1.0, 1.0);
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
      InputType *data = t->rowwise_cpu_dptr<InputType>();
      for (size_t idx = 0; idx < size; ++idx) {
        const bool is_negative = (dis_sign(t->gen()) < 0.0);
        double val = dis(t->gen());
        if (is_negative) {
          val = -val;
        }
        data[idx] = static_cast<InputType>(val);
      }
    });
  }
  t->set_scale_inv(1.0);
  t->from_cpu();
}

// Maps the runtime fill-case tag onto the matching compile-time
// specialization of fillCase_special.
template <typename InputEncoding>
void fillCase(Tensor *t, const InputsFillCase fill_case) {
  if (fill_case == InputsFillCase::uniform) {
    fillCase_special<InputEncoding, InputsFillCase::uniform>(t);
  } else if (fill_case == InputsFillCase::zeros) {
    fillCase_special<InputEncoding, InputsFillCase::zeros>(t);
  } else if (fill_case == InputsFillCase::zero_to_minNorm) {
    fillCase_special<InputEncoding, InputsFillCase::zero_to_minNorm>(t);
  } else if (fill_case == InputsFillCase::minNorm_to_maxNorm) {
    fillCase_special<InputEncoding, InputsFillCase::minNorm_to_maxNorm>(t);
  } else if (fill_case == InputsFillCase::maxNorm_to_inf) {
    fillCase_special<InputEncoding, InputsFillCase::maxNorm_to_inf>(t);
  }
}

// Explicit instantiations of fillCase for every input encoding the tests use;
// FP4 is only available when the CUDA 12.8+ headers provide the type.
template void fillCase<byte>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<int16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<int32>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<int64>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp32>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<bf16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp8e4m3>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp8e5m2>(Tensor *t, const InputsFillCase fill_case);
#if FP4_TYPE_SUPPORTED
template void fillCase<fp4e2m1>(Tensor *t, const InputsFillCase fill_case);
#endif
// Draws a random scale from [-2, 1) using the tensor's own generator and
// stores it on the tensor.
void setRandomScale(Tensor *t) {
  std::uniform_real_distribution<> scale_dist(-2.0, 1.0);
  const float drawn_scale = scale_dist(t->gen());
  t->set_scale(drawn_scale);
}

// Draws a random scale-inverse from [-2, 1) using the tensor's own generator
// and stores it on the tensor.
void setRandomScaleInv(Tensor *t) {
  std::uniform_real_distribution<> scale_inv_dist(-2.0, 1.0);
  const float drawn_scale_inv = scale_inv_dist(t->gen());
  t->set_scale_inv(drawn_scale_inv);
}

// True for every 8-bit floating-point dtype, including the E8M0 scale type.
bool isFp8Type(DType type) {
  switch (type) {
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
    case DType::kFloat8E8M0:
      return true;
    default:
      return false;
  }
}

// True for the 4-bit floating-point dtype; always false when the toolkit
// headers predate FP4 support (CUDA < 12.8).
bool isFp4Type(DType type) {
#if FP4_TYPE_SUPPORTED
  return type == DType::kFloat4E2M1;
#else
  (void)type;
  return false;
#endif
}

// Returns the compute capability of device 0 as major*10 + minor
// (e.g. 90 for SM90).
int32_t getDeviceComputeCapability() {
  cudaDeviceProp deviceProp;
  // Fix: the cudaError_t return was ignored; on failure deviceProp is
  // uninitialized and a garbage capability would silently steer the tests.
  NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, 0));
  return 10 * deviceProp.major + deviceProp.minor;
}

// Collapses all but the last axis into one leading dimension; scalar and 1-D
// shapes report a leading dimension of 1.
size_t first_dimension(const std::vector<size_t> &shape) {
  if (shape.size() <= 1) return 1;
  return product(shape, 0, shape.size() - 1);
}

// Size of the innermost axis; scalar shapes report 1.
size_t last_dimension(const std::vector<size_t> &shape) {
  return shape.empty() ? 1 : shape.back();
}

// Computes scale-tensor block dimensions for a rows x cols tensor quantized
// with block_size_rows x block_size_cols blocks. Returns
// {unpadded_blocks_Y, unpadded_blocks_X, padded_blocks_Y, padded_blocks_X},
// where the padded counts are rounded up to the layout's alignment.
std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
                                            const size_t cols,
                                            const size_t block_size_rows,
                                            const size_t block_size_cols) {
    // 1xN blocks (N = 16 or 32) mean rowwise scaling; anything else colwise.
    const bool rowwise_scaling = (block_size_rows == 1)
                                 && ((block_size_cols == 32) || (block_size_cols == 16));

    const size_t align_Y = rowwise_scaling ? scale_tensor_alignment_Y_rowwise
                                           : scale_tensor_alignment_Y_colwise;
    const size_t align_X = rowwise_scaling ? scale_tensor_alignment_X_rowwise
                                           : scale_tensor_alignment_X_colwise;

    const size_t raw_blocks_Y = divide_round_up(rows, block_size_rows);
    const size_t raw_blocks_X = divide_round_up(cols, block_size_cols);

    const size_t padded_blocks_Y = round_up_to_nearest_multiple(raw_blocks_Y, align_Y);
    const size_t padded_blocks_X = round_up_to_nearest_multiple(raw_blocks_X, align_X);
    return {raw_blocks_Y, raw_blocks_X, padded_blocks_Y, padded_blocks_X};
}

}  // namespace test