/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_

#include "../common.h"
#include "transformer_engine/transformer_engine.h"

namespace transformer_engine::detail {

void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream);

template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ComputeType, typename ParamOP,
          ComputeType (*OP)(ComputeType, const ParamOP &)>
void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *output,
                          Tensor *dbias, Tensor *workspace, cudaStream_t stream);

template <typename ComputeType, typename ParamOP, ComputeType (*OP1)(ComputeType, const ParamOP &),
          ComputeType (*OP2)(ComputeType, const ParamOP &)>
void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, Tensor *output,
                               cudaStream_t stream);

27
28
29
30
void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
                                         SimpleTensor &scale_inv_t, SimpleTensor &output,
                                         SimpleTensor &output_t, const float epsilon,
                                         const bool return_transpose, const bool pow_2_scale,
31
                                         const SimpleTensor &noop_tensor, cudaStream_t stream);
32

33
34
// enum class for rowwise usage
enum class FP8BlockwiseRowwiseOption {
35
  // No rowwise data, skip rowwise quantization
36
37
  NONE,
  // Rowwise data, scales in GEMM format
38
39
40
  ROWWISE_GEMM_READY,
  // Rowwise data, scales in compact format, needs extra processing (padding, transposing) before GEMM
  ROWWISE_COMPACT
41
42
43
44
45
};

// enum class for columnwise usage
// For Hopper sm90 with only TN fp8 gemm, there is need to do columnwise transpose when doing 1D block scaling
enum class FP8BlockwiseColumnwiseOption {
46
  // No columnwise data, skip columnwise quantization
47
48
49
  NONE,
  // Columnwise data transposed from original shape.
  // Scales in GEMM format corresponding to GEMM ingesting transposed column data.
50
51
52
53
54
55
  // On Hopper sm90, GEMM_READY means that columnwise quantization also fuses transpose op
  // On higher sm versions with TN,NT,NN fp8 gemm, GEMM_READY doesn't fuse transpose
  COLUMNWISE_GEMM_READY,
  // Columnwise data in original shape
  // Scales in compact format, needs extra processing (padding, transposing) before GEMM
  COLUMNWISE_COMPACT
56
57
};

58
59
60
void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
                                         SimpleTensor &scale_inv_t, SimpleTensor &output,
                                         SimpleTensor &output_t, const float epsilon,
61
62
                                         FP8BlockwiseRowwiseOption rowwise_option,
                                         FP8BlockwiseColumnwiseOption columnwise_option,
63
64
                                         const bool pow_2_scale, const SimpleTensor &noop_tensor,
                                         cudaStream_t stream);
65

66
67
68
69
70
71
72
73
void quantize_transpose_vector_blockwise_fp4(
    const SimpleTensor &input, const SimpleTensor &global_amax, SimpleTensor &scale_inv,
    SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon,
    const bool return_identity, const bool return_transpose, const bool pow2_scale,
    const bool swizzled_scale, const bool use_stochastic_rounding,
    const NVTETensor rng_state_tensor, const bool use_2d_quantization,
    const SimpleTensor &noop_tensor, cudaStream_t stream);

74
75
76
}  // namespace transformer_engine::detail

#endif  // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_