/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <transformer_engine/cast.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"

using namespace transformer_engine;
using namespace test;

namespace {

// Fused operation performed by the kernel under test before MXFP8 quantization.
enum ProcessingMethod {
    CAST_ONLY       = 0,  // quantize `input` as-is
    CAST_DBIAS      = 1,  // quantize `grad` and reduce it column-wise into dbias
    CAST_DBIAS_DACT = 2,  // quantize dact(input) * grad and reduce it into dbias
    CAST_DACT       = 3,  // quantize dact(input) * grad
    CAST_ACT        = 4   // quantize act(input)
};

// Activation used by the *_ACT / *_DACT methods. Identity is paired with the
// non-activation methods (CAST_ONLY, CAST_DBIAS).
enum ActivationType {
    Identity = 0,
    GeLU     = 1,
    SiLU     = 2,
    ReLU     = 3,
    QGeLU    = 4,
    SReLU    = 5
};
template <typename InputType, typename OutputType>
void compute_ref(const ProcessingMethod processing_method,
                 float (*OP)(const float),
                 const bool rowwise,
                 const bool colwise,
47
48
                 const InputType* input,
                 const InputType* grad,
49
50
51
52
53
54
55
56
57
                 OutputType* output_rowwise,
                 OutputType* output_colwise,
                 fp8e8m0* output_scales_rowwise,
                 fp8e8m0* output_scales_colwise,
                 InputType* output_dbias,
                 const size_t rows,
                 const size_t cols,
                 const size_t scales_stride_rowwise,
                 const size_t scales_stride_colwise)
58
{
59
60
    const size_t tile_size_Y = 32;
    const size_t tile_size_X = 32;
61
62
63
    const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
    const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;

64
    std::vector<float> output_dbias_fp32(cols, 0);
65
66
    #pragma omp parallel proc_bind(spread)
    {
67
68
69
        // Buffers to cache intermediate computations
        std::vector<float> cache_buffer(tile_size_Y * tile_size_X);

70
71
72
73
74
75
76
77
        std::vector<float> thread_dbias(cols, 0);
        #pragma omp for schedule(static)
        for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) {
            const size_t tile_Y = t / tiles_num_X;
            const size_t tile_X = t % tiles_num_X;
            const size_t tile_offset_Y = tile_Y * tile_size_Y;
            const size_t tile_offset_X = tile_X * tile_size_X;

78
79
80
81
82
83
84
85
86
            const size_t i_min = tile_offset_Y;
            const size_t i_max = std::min(i_min + tile_size_Y, rows);

            const size_t j_min = tile_offset_X;
            const size_t j_max = std::min(j_min + tile_size_X, cols);

            // Cache computations
            for (size_t i = i_min; i < i_max; ++i) {
                for (size_t j = j_min; j < j_max; ++j) {
87

88
89
                    const size_t idx = i * cols + j;
                    const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109

                    float elt = static_cast<float>(input[idx]);
                    if (processing_method == ProcessingMethod::CAST_DBIAS) {
                        // grad is the input
                        elt = static_cast<float>(grad[idx]);
                    }
                    if (processing_method != ProcessingMethod::CAST_ONLY
                        && processing_method != ProcessingMethod::CAST_DBIAS) {
                        elt = OP(elt);
                    }
                    if (processing_method == ProcessingMethod::CAST_DACT ||
                        processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
                        elt *= static_cast<float>(grad[idx]);
                    }
                    thread_dbias[j] += elt;

                    // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32
                    elt = static_cast<float>(static_cast<InputType>(elt));

                    cache_buffer[cache_idx] = elt;
wenjh's avatar
wenjh committed
110
                    if (std::isinf(elt) || std::isnan(elt)) {
111
112
113
114
115
116
117
118
119
120
                        continue;
                    }
                }
            }

            if (rowwise) {
                for (size_t i = i_min; i < i_max; ++i) {
                    float block_amax = 0.0f;

                    for (size_t j = j_min; j < j_max; ++j) {
121
                        const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
122
123
124
125
                        block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
                    }

                    const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
126
                    const size_t scale_idx = i * scales_stride_rowwise + tile_X;
127
128
129
130
                    output_scales_rowwise[scale_idx] = biased_exponent;
                    const float scale_reciprocal = exp2f_rcp(biased_exponent);

                    for (size_t j = j_min; j < j_max; ++j) {
131
132
                        const size_t idx = i * cols + j;
                        const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
133
134
135
136
137
138
139
140
141
                        output_rowwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
                    }
                }
            }
            if (colwise) {
                for (size_t j = j_min; j < j_max; ++j) {
                    float block_amax = 0.0f;

                    for (size_t i = i_min; i < i_max; ++i) {
142
                        const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
143
144
145
146
                        block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
                    }

                    const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
147
                    const size_t scale_idx = tile_Y * scales_stride_colwise + j;
148
149
150
151
                    output_scales_colwise[scale_idx] = biased_exponent;
                    const float scale_reciprocal = exp2f_rcp(biased_exponent);

                    for (size_t i = i_min; i < i_max; ++i) {
152
153
                        const size_t idx = i * cols + j;
                        const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
154
155
                        output_colwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
                    }
156
157
158
159
160
161
162
163
                }
            }
        }
        #pragma omp critical
        {
            for (size_t j = 0; j < cols; ++j) {
                output_dbias_fp32[j] += thread_dbias[j];
            }
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
        }
    }
    for (size_t j = 0; j < cols; ++j) {
        output_dbias[j] = static_cast<InputType>(output_dbias_fp32[j]);
    }
}

/**
 * Scaling along single dimension (either rows or columns)
 * Produces one set of output data and the corresponding data of the fused operation (dbias):
 * 1) Scaled rows + row-wise scaling factors
 *       OR
 * 2) Scaled columns + column-wise scaling factors
 */

179
template <typename InputType, typename OutputType>
180
void performTest_x1(const ProcessingMethod processing_method,
181
                    float (*OP)(const float),
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
                    const std::vector<size_t>& shape,
                    const bool rowwise,
                    const bool colwise,
                    InputsFillCase fill_case) {
    using namespace test;
    using EncodingType = fp32;
    DType itype = TypeInfo<InputType>::dtype;
    DType otype = TypeInfo<OutputType>::dtype;

    const size_t rows = first_dimension(shape);
    const size_t cols = last_dimension(shape);

    if (shape.size() < 2 && colwise) {
      GTEST_SKIP();
    }

    const size_t block_size_rows = rowwise ? 1 : 32;
    const size_t block_size_cols = colwise ? 1 : 32;

    const std::array<size_t,4> scale_dims = get_scale_tensor_dims(rows, cols, block_size_rows,
                                                                  block_size_cols);

    const size_t unpadded_blocks_Y = scale_dims[0];
    const size_t unpadded_blocks_X = scale_dims[1];
    const size_t blocks_Y = scale_dims[2];
    const size_t blocks_X = scale_dims[3];
    const size_t scales_stride = blocks_X;

    Tensor input("input", shape, itype);
    Tensor grad("grad", shape, itype);
    Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING);
213
    Tensor output_dbias("output_dbias", std::vector<size_t>{ cols }, itype);
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243

    std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(rows * cols);
    std::unique_ptr<InputType[]> ref_output_dbias = std::make_unique<InputType[]>(cols);
    std::unique_ptr<fp8e8m0[]> ref_output_scales = std::make_unique<fp8e8m0[]>(blocks_Y * blocks_X);

    fillCase<EncodingType>(&input, fill_case);
    fillUniform(&grad);

    Tensor workspace;
    switch (processing_method) {
        case ProcessingMethod::CAST_ONLY: {
            nvte_quantize(input.data(), output_c.data(), 0);
            break;
        }
        case ProcessingMethod::CAST_DBIAS: {
            nvte_quantize_dbias(grad.data(),
                                output_c.data(),
                                output_dbias.data(),
                                workspace.data(),
                                0);
            workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());

            nvte_quantize_dbias(grad.data(),
                                output_c.data(),
                                output_dbias.data(),
                                workspace.data(),
                                0);
            break;
        }
        case ProcessingMethod::CAST_DBIAS_DACT: {
244
245
246
247
248
249
250
251
252
253
254
255
            auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu;
            if (OP == &dsilu)       { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; }
            else if (OP == &drelu)  { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; }
            else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; }
            else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; }

            nvte_quantize_dbias_dact(grad.data(),
                                     input.data(),
                                     output_c.data(),
                                     output_dbias.data(),
                                     workspace.data(),
                                     0);
256
257
            workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());

258
259
260
261
262
263
            nvte_quantize_dbias_dact(grad.data(),
                                     input.data(),
                                     output_c.data(),
                                     output_dbias.data(),
                                     workspace.data(),
                                     0);
264
265
266
            break;
        }
        case ProcessingMethod::CAST_DACT: {
267
268
269
270
271
272
273
            auto nvte_dact = &nvte_dgelu;
            if (OP == &dsilu)       { nvte_dact = &nvte_dsilu; }
            else if (OP == &drelu)  { nvte_dact = &nvte_drelu; }
            else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; }
            else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; }

            nvte_dact(grad.data(), input.data(), output_c.data(), 0);
274
275
276
            break;
        }
        case ProcessingMethod::CAST_ACT: {
277
278
279
280
281
282
283
            auto nvte_act = &nvte_gelu;
            if (OP == &silu)       { nvte_act = &nvte_silu; }
            else if (OP == &relu)  { nvte_act = &nvte_relu; }
            else if (OP == &qgelu) { nvte_act = &nvte_qgelu; }
            else if (OP == &srelu) { nvte_act = &nvte_srelu; }

            nvte_act(input.data(), output_c.data(), 0);
284
285
286
287
288
289
290
291
            break;
        }
    }

    cudaDeviceSynchronize();
    auto err = cudaGetLastError();
    ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
    compute_ref<InputType, OutputType>(processing_method,
                                       OP,
                                       rowwise,
                                       colwise,
                                       input.rowwise_cpu_dptr<InputType>(),
                                       grad.rowwise_cpu_dptr<InputType>(),
                                       ref_output_c.get(),
                                       ref_output_c.get(),
                                       ref_output_scales.get(),
                                       ref_output_scales.get(),
                                       ref_output_dbias.get(),
                                       rows,
                                       cols,
                                       scales_stride,
                                       scales_stride);
307
308
309
310
311

    const uint8_t * const gpu_scales_ptr = rowwise
                                           ? output_c.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
                                           : output_c.columnwise_cpu_scale_inv_ptr<fp8e8m0>();

312
313
314
315
316
    const size_t scale_diff_abs_tolerance = 0;
    const double abs_tolerable_mismatches_limit = 0.0;
    const double rel_tolerable_mismatches_limit = 0.0;

    size_t mismatches_scales = 0;
317
318
319
320
321
322
323

    compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
                            unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
                            mismatches_scales,
                            scale_diff_abs_tolerance,
                            abs_tolerable_mismatches_limit,
                            rel_tolerable_mismatches_limit);
324
325
326
327

    const size_t mismatches_elts = 32 * mismatches_scales;
    auto [atol, rtol] = getTolerances(otype);
    compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol, true, mismatches_elts);
328

329
330
331
    if (processing_method == ProcessingMethod::CAST_DBIAS
        || processing_method == ProcessingMethod::CAST_DBIAS_DACT)
    {
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
        auto [atol_dbias, rtol_dbias] = getTolerances(itype);
        if (itype == DType::kFloat32) {
            atol_dbias = 1e-4;
            rtol_dbias *= sqrt(static_cast<double>(rows)) ;
        } else {
            rtol_dbias *= 4;
        }
        compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
    }
}

/**
 * Scaling along both dimensions (rows and columns)
 * Produces two sets of scaled output data and the corresponding data of the fused operation (dbias):
 * 1) Scaled rows + row-wise scaling factors
 *      AND
 * 2) Scaled columns + column-wise scaling factors
 */
350
template <typename InputType, typename OutputType>
351
void performTest_x2(const ProcessingMethod processing_method,
352
                    float (*OP)(const float),
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
                    const std::vector<size_t>& shape,
                    const size_t block_size_rows,
                    const size_t block_size_cols,
                    InputsFillCase fill_case) {
    using namespace test;
    using EncodingType = fp32;
    DType itype = TypeInfo<InputType>::dtype;
    DType otype = TypeInfo<OutputType>::dtype;

    if (shape.size() < 2) {
      GTEST_SKIP();
    }

    const size_t rows = first_dimension(shape);
    const size_t cols = last_dimension(shape);

    const std::array<size_t,4> scale_dims_rowwise = get_scale_tensor_dims(rows, cols, 1, 32);
    const std::array<size_t,4> scale_dims_colwise = get_scale_tensor_dims(rows, cols, 32, 1);

    const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0];
    const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1];
    const size_t blocks_Y_rowwise = scale_dims_rowwise[2];
    const size_t blocks_X_rowwise = scale_dims_rowwise[3];
    const size_t scales_stride_rowwise = blocks_X_rowwise;

    const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0];
    const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1];
    const size_t blocks_Y_colwise = scale_dims_colwise[2];
    const size_t blocks_X_colwise = scale_dims_colwise[3];
    const size_t scales_stride_colwise = blocks_X_colwise;

    Tensor input("input", shape, itype);
    Tensor grad("grad", shape, itype);
    Tensor output("output", shape, otype, true, true, NVTE_MXFP8_1D_SCALING);
387
    Tensor output_dbias("output_dbias", std::vector<size_t>{ cols }, itype);
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419

    std::unique_ptr<OutputType[]> ref_output_c_rowwise = std::make_unique<OutputType[]>(rows * cols);
    std::unique_ptr<OutputType[]> ref_output_c_colwise = std::make_unique<OutputType[]>(rows * cols);
    std::unique_ptr<fp8e8m0[]> ref_scales_rowwise = std::make_unique<fp8e8m0[]>(blocks_Y_rowwise * blocks_X_rowwise);
    std::unique_ptr<fp8e8m0[]> ref_scales_colwise = std::make_unique<fp8e8m0[]>(blocks_Y_colwise * blocks_X_colwise);
    std::unique_ptr<InputType[]> ref_output_dbias = std::make_unique<InputType[]>(cols);

    fillCase<EncodingType>(&input, fill_case);
    fillUniform(&grad);

    Tensor workspace;
    switch (processing_method) {
        case ProcessingMethod::CAST_ONLY: {
            nvte_quantize(input.data(), output.data(), 0);
            break;
        }
        case ProcessingMethod::CAST_DBIAS: {
            nvte_quantize_dbias(grad.data(),
                                output.data(),
                                output_dbias.data(),
                                workspace.data(),
                                0);
            workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());

            nvte_quantize_dbias(grad.data(),
                                output.data(),
                                output_dbias.data(),
                                workspace.data(),
                                0);
            break;
        }
        case ProcessingMethod::CAST_DBIAS_DACT: {
420
421
422
423
424
425
426
427
428
429
430
431
            auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu;
            if (OP == &dsilu)       { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; }
            else if (OP == &drelu)  { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; }
            else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; }
            else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; }

            nvte_quantize_dbias_dact(grad.data(),
                                     input.data(),
                                     output.data(),
                                     output_dbias.data(),
                                     workspace.data(),
                                     0);
432
433
            workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());

434
435
436
437
438
439
            nvte_quantize_dbias_dact(grad.data(),
                                     input.data(),
                                     output.data(),
                                     output_dbias.data(),
                                     workspace.data(),
                                     0);
440
441
442
            break;
        }
        case ProcessingMethod::CAST_DACT: {
443
444
445
446
447
448
449
            auto nvte_dact = &nvte_dgelu;
            if (OP == &dsilu)       { nvte_dact = &nvte_dsilu; }
            else if (OP == &drelu)  { nvte_dact = &nvte_drelu; }
            else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; }
            else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; }

            nvte_dact(grad.data(), input.data(), output.data(), 0);
450
451
452
            break;
        }
        case ProcessingMethod::CAST_ACT: {
453
454
455
456
457
458
459
            auto nvte_act = &nvte_gelu;
            if (OP == &silu)       { nvte_act = &nvte_silu; }
            else if (OP == &relu)  { nvte_act = &nvte_relu; }
            else if (OP == &qgelu) { nvte_act = &nvte_qgelu; }
            else if (OP == &srelu) { nvte_act = &nvte_srelu; }

            nvte_act(input.data(), output.data(), 0);
460
461
462
463
464
465
466
467
            break;
        }
    }

    cudaDeviceSynchronize();
    auto err = cudaGetLastError();
    ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
    compute_ref<InputType, OutputType>(processing_method,
                                       OP,
                                       true,
                                       true,
                                       input.rowwise_cpu_dptr<InputType>(),
                                       grad.rowwise_cpu_dptr<InputType>(),
                                       ref_output_c_rowwise.get(),
                                       ref_output_c_colwise.get(),
                                       ref_scales_rowwise.get(),
                                       ref_scales_colwise.get(),
                                       ref_output_dbias.get(),
                                       rows,
                                       cols,
                                       scales_stride_rowwise,
                                       scales_stride_colwise);

    const size_t scale_diff_abs_tolerance = 0;
    const double abs_tolerable_mismatches_limit = 0.0;
    const double rel_tolerable_mismatches_limit = 0.0;

    size_t mismatches_scales_rowwise = 0;
489
490
491
492
493
494
495
    compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
                            ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
                            unpadded_blocks_X_rowwise, scales_stride_rowwise,
                            mismatches_scales_rowwise,
                            scale_diff_abs_tolerance,
                            abs_tolerable_mismatches_limit,
                            rel_tolerable_mismatches_limit);
496
497

    size_t mismatches_scales_colwise = 0;
498
499
500
501
502
503
504
    compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
                            ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
                            unpadded_blocks_X_colwise, scales_stride_colwise,
                            mismatches_scales_colwise,
                            scale_diff_abs_tolerance,
                            abs_tolerable_mismatches_limit,
                            rel_tolerable_mismatches_limit);
505
506
507
508
509
510
511

    const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise;
    const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise;

    auto [atol, rtol] = getTolerances(otype);
    compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol, true, mismatches_elts_rowwise);
    compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol, true, mismatches_elts_colwise);
512

513
514
515
    if (processing_method == ProcessingMethod::CAST_DBIAS
        || processing_method == ProcessingMethod::CAST_DBIAS_DACT)
    {
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
        auto [atol_dbias, rtol_dbias] = getTolerances(itype);
        if (itype == DType::kFloat32) {
            atol_dbias = 1e-4;
            rtol_dbias *= sqrt(static_cast<double>(rows)) ;
        } else {
            rtol_dbias *= 4;
        }
        compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias);
    }
}

// Tensor shapes run by every test case: 2-D shapes first (several are not multiples of 32),
// followed by a 1-D shape and higher-rank shapes.
std::vector<std::vector<size_t>> matrix_sizes = {
    // 2-D
    {1, 16},     {16, 48},    {65, 96},
    {128, 128},  {256, 256},  {993, 512},
    {511, 6144}, {8192, 128}, {2048, 160},
    {577, 1632},
    // 1-D and N-D
    {1024},
    {8, 32, 1024},
    {16, 8, 4, 512},
};

// (block rows, block cols) pairs:
//   {1, 32}  -> row-wise scaling only
//   {32, 1}  -> column-wise scaling only
//   {32, 32} -> both directions at once
std::vector<std::pair<size_t, size_t>> block_sizes = {
    {1, 32}, {32, 1}, {32, 32},
};

// Input fill patterns. Only the uniform pattern is enabled; the others stay listed
// (commented out) so they can be switched back on.
std::vector<InputsFillCase> input_scenarios = {
    InputsFillCase::uniform,
    // InputsFillCase::zeros,
    // InputsFillCase::zero_to_minNorm,
    // InputsFillCase::minNorm_to_maxNorm,
    // InputsFillCase::maxNorm_to_inf
};

// Every fused operation, in declaration order.
std::vector<ProcessingMethod> processing_methods = {
    ProcessingMethod::CAST_ONLY,
    ProcessingMethod::CAST_DBIAS,
    ProcessingMethod::CAST_DBIAS_DACT,
    ProcessingMethod::CAST_DACT,
    ProcessingMethod::CAST_ACT,
};

// Activations under test. Only GeLU activation tests are enabled; Identity is what the
// non-activation methods (CAST_ONLY, CAST_DBIAS) run with.
std::vector<ActivationType> Activation_types = {
    ActivationType::Identity,
    ActivationType::GeLU,
    // ActivationType::SiLU,
    // ActivationType::ReLU,
    // ActivationType::QGeLU,
    // ActivationType::SReLU,
};

}  // namespace

// Value-parameterized fixture. Each parameter tuple holds:
//   (processing method, activation, tensor shape, (block rows, block cols),
//    input dtype, output dtype, input fill case)
class FusedCastMXFP8TestSuite
    : public ::testing::TestWithParam<std::tuple<ProcessingMethod,
                                                 ActivationType,
                                                 std::vector<size_t>,
                                                 std::pair<size_t, size_t>,
                                                 transformer_engine::DType,
                                                 transformer_engine::DType,
                                                 InputsFillCase>> {};

// Runs one fused MXFP8 cast configuration. Methods and activations that do not pair
// are skipped. The reference (de)activation is then selected, and the test dispatches to
// performTest_x1 (single direction) or performTest_x2 (both directions).
TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) {
    // Skip tests for pre-Blackwell architectures
    if (getDeviceComputeCapability() < blackwellComputeCapability) {
        GTEST_SKIP();
    }

    using namespace transformer_engine;
    using namespace test;

    const ProcessingMethod processing_method = std::get<0>(GetParam());
    const ActivationType Act_type = std::get<1>(GetParam());
    const auto matrix_size = std::get<2>(GetParam());
    const auto block_size = std::get<3>(GetParam());
    const DType input_type = std::get<4>(GetParam());
    const DType output_type = std::get<5>(GetParam());
    const InputsFillCase fill_case = std::get<6>(GetParam());

    // Skips non Act tests if the Activation type is not an identity
    if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS)
        && Act_type != ActivationType::Identity) {
        GTEST_SKIP();
    }
    // Skips Act tests if the Activation is an identity
    if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT
        || processing_method == ProcessingMethod::CAST_DACT
        || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) {
        GTEST_SKIP();
    }

    const bool rowwise = block_size.second != 1;
    const bool colwise = block_size.first != 1;

    // CAST_ACT is checked against the forward activation; every other method against the
    // activation derivative. Identity stays selected for the non-activation methods.
    float (*OP)(const float) = &identity;
    if (processing_method == ProcessingMethod::CAST_ACT) {
        // Forward activations
        switch (Act_type) {
            case ActivationType::Identity:              break;
            case ActivationType::GeLU:  OP = &gelu;     break;
            case ActivationType::SiLU:  OP = &silu;     break;
            case ActivationType::ReLU:  OP = &relu;     break;
            case ActivationType::QGeLU: OP = &qgelu;    break;
            case ActivationType::SReLU: OP = &srelu;    break;
        }
    } else {
        // Activation derivatives
        switch (Act_type) {
            case ActivationType::Identity:              break;
            case ActivationType::GeLU:  OP = &dgelu;    break;
            case ActivationType::SiLU:  OP = &dsilu;    break;
            case ActivationType::ReLU:  OP = &drelu;    break;
            case ActivationType::QGeLU: OP = &dqgelu;   break;
            case ActivationType::SReLU: OP = &dsrelu;   break;
        }
    }

    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
        TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
            if (block_size.first == 1 || block_size.second == 1) {
                performTest_x1<InputType, OutputType>(
                    processing_method, OP, matrix_size,
                    rowwise, colwise, fill_case);
            } else {
                performTest_x2<InputType, OutputType>(
                    processing_method, OP, matrix_size,
                    block_size.first, block_size.second, fill_case);
            }
        );
    );
}

// Name of a processing method as used in generated test names; "" for unknown values.
std::string to_string(const ProcessingMethod method) {
    const char* label = "";
    switch (method) {
        case ProcessingMethod::CAST_ONLY:       label = "CAST_ONLY";       break;
        case ProcessingMethod::CAST_DBIAS:      label = "CAST_DBIAS";      break;
        case ProcessingMethod::CAST_DBIAS_DACT: label = "CAST_DBIAS_DACT"; break;
        case ProcessingMethod::CAST_DACT:       label = "CAST_DACT";       break;
        case ProcessingMethod::CAST_ACT:        label = "CAST_ACT";        break;
        default:                                                           break;
    }
    return std::string(label);
}

// Name of an activation type as used in generated test names; "" for unknown values.
std::string to_string(const ActivationType Act_type) {
    const char* label = "";
    switch (Act_type) {
        case ActivationType::Identity: label = "Identity"; break;
        case ActivationType::GeLU:     label = "GeLU";     break;
        case ActivationType::SiLU:     label = "SiLU";     break;
        case ActivationType::ReLU:     label = "ReLU";     break;
        case ActivationType::QGeLU:    label = "QGeLU";    break;
        case ActivationType::SReLU:    label = "SReLU";    break;
        default:                                           break;
    }
    return std::string(label);
}

// Full cartesian product of the parameter tables. The test name joins every parameter
// with "X": method, activation, each shape dimension, block rows, block cols,
// input dtype, output dtype, fill case.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    FusedCastMXFP8TestSuite,
    ::testing::Combine(
        ::testing::ValuesIn(processing_methods),
        ::testing::ValuesIn(Activation_types),
        ::testing::ValuesIn(matrix_sizes),
        ::testing::ValuesIn(block_sizes),
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
        ::testing::ValuesIn(input_scenarios)),
    [](const testing::TestParamInfo<FusedCastMXFP8TestSuite::ParamType>& info) {
        const auto& param = info.param;
        std::string name = to_string(std::get<0>(param));
        const auto append = [&name](const std::string& part) {
            name += "X";
            name += part;
        };

        append(to_string(std::get<1>(param)));
        for (const auto& dim : std::get<2>(param)) {
            append(std::to_string(dim));
        }
        const auto& block = std::get<3>(param);
        append(std::to_string(block.first));
        append(std::to_string(block.second));
        append(test::typeName(std::get<4>(param)));
        append(test::typeName(std::get<5>(param)));
        append(test::caseName(std::get<6>(param)));
        return name;
    });