/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cudnn.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include "common/common.h"
#include "common/util/logging.h"

#include "extensions/ffi.h"
#include "extensions/misc.h"
#include "extensions/utils.h"
#include "transformer_engine/activation.h"

// ENUM_ATTR and DICT_ATTR decoding needs to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);

namespace transformer_engine {
namespace jax {

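// True iff `type` is one of the FP8 storage dtypes (E4M3 or E5M2).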
inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }

// Activation

XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler);

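// Workspace-size query for the fused dact + dbias + quantize primitive above; returning a
// pybind11::tuple lets the Python side size workspace buffers before the FFI call.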
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType out_dtype,
                                                   JAXX_Scaling_Mode scaling_mode, bool is_2x);

// Normalization
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler);

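// Workspace-size query for the normalization (LayerNorm/RMSNorm) forward kernel.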
pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
                                             DType w_dtype, DType out_dtype,
                                             NVTE_Norm_Type norm_type,
                                             JAXX_Scaling_Mode scaling_mode,
                                             bool zero_centered_gamma, float epsilon, int sm_margin,
                                             bool is_training);

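// Workspace-size query for the normalization backward kernel.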
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
                                              DType w_dtype, NVTE_Norm_Type norm_type,
                                              bool zero_centered_gamma, int sm_margin);

// Quantization
XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);

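// Workspace-size query for the fused dbias + quantize primitive; q_layout selects which
// quantized output layouts (rowwise/colwise) are produced.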
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                               DType in_dtype, DType out_dtype,
                                               JAXX_Scaling_Mode scaling_mode,
                                               QuantizeLayout q_layout);

// Softmax
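// Forward/backward handlers for the fused scaled softmax variants: unmasked, arbitrary mask,
// and causal (upper-triangular) mask.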
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler);

// Attention
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);

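// Queries which cuDNN fused-attention backend (if any) supports the given dtypes, QKV layout,
// bias/mask types, shapes, and sliding-window configuration.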
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_num_heads, size_t kv_num_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim, int64_t window_size_left,
                                            int64_t window_size_right);

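// Workspace-size query for the fused-attention forward pass.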
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);

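// Workspace-size query for the fused-attention backward pass; `deterministic` requests the
// deterministic dgrad path, which may require a different workspace size.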
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
    int64_t window_size_right);

// Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);

// cuDNN helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);

// cuBLAS helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler);

}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_