/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cudnn.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include "common/common.h"
#include "common/util/logging.h"
#include "extensions/ffi.h"
#include "extensions/misc.h"
#include "extensions/utils.h"
#include "transformer_engine/activation.h"
#include "transformer_engine/multi_stream.h"

namespace transformer_engine {
namespace jax {

// Attribute bundle for the clamped-SwiGLU activation variant. Field names must
// match the FFI attribute names ("limit", "alpha") registered for this struct
// via XLA_FFI_REGISTER_STRUCT_ATTR_DECODING at the bottom of this header.
struct ClampedSwigluConfig {
  float limit;  // clamp bound applied by the activation — TODO(review): confirm exact semantics
  float alpha;  // scaling coefficient for the clamped variant — TODO(review): confirm semantics
};

// Aggregate of activation configuration passed across the XLA FFI boundary;
// currently carries only the clamped-SwiGLU parameters, decoded from the
// "clamped_swiglu" attribute (see the struct-decoding registration below).
struct ActivationConfig {
  ClampedSwigluConfig clamped_swiglu;
};

// True iff `type` is one of the FP8 storage dtypes (E4M3 or E5M2).
inline bool use_fp8(DType type) {
  switch (type) {
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
      return true;
    default:
      return false;
  }
}
49

50
51
// Activation

52
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);
53
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuInitializeHandler);
54

55
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler);
56
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler);
57
58
59

// Workspace-size query for the fused dact + dbias + quantize operation;
// result is returned to Python as a pybind11 tuple.
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType out_dtype,
                                                   JAXX_Scaling_Mode scaling_mode,
                                                   JAXX_Quantize_Layout quantize_layout);

63
// Normalization
64
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardInitializeHandler);
65
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler);
66

67
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardInitializeHandler);
68
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler);
69

70
71
// Workspace-size query for the normalization forward pass.
pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
                                             DType w_dtype, DType out_dtype,
                                             NVTE_Norm_Type norm_type,
                                             JAXX_Scaling_Mode scaling_mode,
                                             bool zero_centered_gamma, float epsilon, int sm_margin,
                                             bool is_training);

77
78
79
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
                                              DType w_dtype, NVTE_Norm_Type norm_type,
                                              bool zero_centered_gamma, int sm_margin);
80

81
// Quantization
82
XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);
83

84
85
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler);

86
87
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);

88
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
89
                                               DType in_dtype, DType out_dtype, DType scale_dtype,
90
                                               JAXX_Scaling_Mode scaling_mode,
91
                                               JAXX_Quantize_Layout quantize_layout);
92

// Softmax
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler);

// Attention
107
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
108

109
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);
110

111
112
113
114
115
116
// Returns the NVTE fused-attention backend selected for the given problem
// configuration (dtypes, QKV layout, bias/mask/softmax types, head/sequence
// sizes, dropout, and sliding-window bounds).
NVTE_Fused_Attn_Backend GetFusedAttnBackend(
    bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
    float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
    size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
    int64_t window_size_right);
117

118
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
119
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
120
121
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
122
123
124
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
    DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left,
    int64_t window_size_right);
125
126

// Workspace-size query for the fused-attention backward pass; mirrors the
// forward query with an extra `deterministic` flag.
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
    DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
    int64_t window_size_left, int64_t window_size_right);

// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler);

// Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);

// Amax
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler);

// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);

// CuBLAS helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler);

}  // namespace jax
}  // namespace transformer_engine

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig,
                                      ::xla::ffi::StructMember<float>("limit"),
                                      ::xla::ffi::StructMember<float>("alpha"));

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(
    transformer_engine::jax::ActivationConfig,
    ::xla::ffi::StructMember<transformer_engine::jax::ClampedSwigluConfig>("clamped_swiglu"));

// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op);
166
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Quantize_Layout);
#endif  // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_