/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cudnn.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include "common/common.h"
#include "common/util/logging.h"
#include "extensions/ffi.h"
#include "extensions/misc.h"
#include "extensions/utils.h"
#include "transformer_engine/activation.h"
#include "transformer_engine/multi_stream.h"

// ENUM_ATTR and DICT_ATTR decoding must be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);

namespace transformer_engine {
namespace jax {

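// True iff `type` is one of the FP8 dtypes these bindings support. A minimal
// usage sketch (hypothetical caller, not part of this header):
//   if (use_fp8(out_dtype)) { /* attach FP8 scale/amax metadata before launch */ }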
inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }

// Activation

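// ActLuHandler is the fused activation forward; DActLuDBiasQuantizeHandler fuses
// the activation backward with a dbias reduction and FP8 quantization (pairing
// inferred from the declared names).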
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler);

pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType out_dtype,
                                                   JAXX_Scaling_Mode scaling_mode, bool is_2x);
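// The size queries in this header share one pattern: the Python-side lowering
// calls them at trace time and reserves matching workspace buffers for the FFI
// handler. A minimal binding sketch, assuming a pybind11 module `m` (name
// hypothetical):
//   m.def("get_dact_dbias_quantize_workspace_sizes",
//         &GetDActDBiasQuantizeWorkspaceSizes);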

// Normalization
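// Both handlers cover LayerNorm and RMSNorm; the variant is chosen through
// NVTE_Norm_Type, as in the workspace queries below.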
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler);

pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
                                             DType w_dtype, DType out_dtype,
                                             NVTE_Norm_Type norm_type,
                                             JAXX_Scaling_Mode scaling_mode,
                                             bool zero_centered_gamma, float epsilon, int sm_margin,
                                             bool is_training);

pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
                                              DType w_dtype, NVTE_Norm_Type norm_type,
                                              bool zero_centered_gamma, int sm_margin);
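// Shape convention assumed by these queries (inferred from common TE JAX usage):
// an activation of shape [B, S, H] is viewed as a 2D matrix with
// batch_size = B * S rows and hidden_size = H columns.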

// Quantization
XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);

pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                               DType in_dtype, DType out_dtype,
                                               JAXX_Scaling_Mode scaling_mode,
                                               QuantizeLayout q_layout);

// Softmax
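// Three fused variants, each with a forward/backward pair: plain scaled,
// explicitly masked, and upper-triangular (causal) masked.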
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler);

// Attention
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);

NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_num_heads, size_t kv_num_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim, int64_t window_size_left,
                                            int64_t window_size_right);
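// Dispatch sketch (hypothetical caller): query the backend first and fall back
// to an unfused path when no fused kernel applies, e.g.
//   if (GetFusedAttnBackend(...) == NVTE_Fused_Attn_Backend::NVTE_No_Backend) {
//     /* lower to the unfused attention implementation */
//   }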

pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);

pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
    int64_t window_size_right);
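// Mirrors the forward query; the extra `deterministic` flag requests
// bitwise-reproducible gradients, typically at some speed cost (TE convention).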

// Grouped GEMM
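// A single FFI call that launches one GEMM per group (e.g. per expert in MoE
// layers); grouping semantics inferred from the handler name.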
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);

// cuDNN helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);

// cuBLAS helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler);
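// Both init handlers are assumed to be one-shot calls that create per-device
// cuDNN / cuBLAS handles up front, so later kernels avoid handle creation on the
// hot path.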

}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_