/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_

10
11
12
13
14
15
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cudnn.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
16
#include <transformer_engine/normalization.h>
17
18
#include <transformer_engine/transformer_engine.h>

19
20
21
#include <cassert>
#include <cstddef>
#include <cstdint>
22
#include <iostream>
23
24
#include <stdexcept>
#include <string>
25
#include <vector>
26

27
#include "common/common.h"
28
#include "common/util/logging.h"
29
30
31
#include "extensions/ffi.h"
#include "extensions/misc.h"
#include "transformer_engine/activation.h"
32
33
#include "utils.h"

34
35
36
namespace transformer_engine {
namespace jax {

37
// Reports whether `type` is one of the two FP8 data types checked here.
// Returns true for DType::kFloat8E4M3 or DType::kFloat8E5M2, false for
// every other value.
inline bool use_fp8(DType type) {
  if (type == DType::kFloat8E4M3) {
    return true;
  }
  return type == DType::kFloat8E5M2;
}
38

39
40
// Activation

41
42
// Declares the XLA FFI handler symbol `ActLuHandler` (activation section).
// Declaration only: the handler body is not part of this header.
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);

43
// Normalization
44
// Declares the XLA FFI handler symbol `NormForwardHandler`.
// Declaration only; going by the name this is the normalization forward
// entry point -- confirm against the definition.
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler);
45

46
// Declares the XLA FFI handler symbol `NormBackwardHandler`.
// Declaration only; going by the name this is the normalization backward
// entry point -- confirm against the definition.
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler);
47

48
49
50
51
52
// Returns a pybind11 tuple for a normalization forward call described by the
// batch/hidden sizes, input/weight/output dtypes, norm type, scaling mode,
// zero-centered-gamma flag, epsilon, SM margin and training flag.
// Declaration only; the definition is not part of this header.
// NOTE(review): the name implies the tuple describes workspace sizes, but the
// tuple layout is not visible here -- confirm against the implementation.
pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
                                             DType w_dtype, DType out_dtype,
                                             NVTE_Norm_Type norm_type, int scaling_mode,
                                             bool zero_centered_gamma, float epsilon, int sm_margin,
                                             bool is_training);
53

54
55
56
// Returns a pybind11 tuple for a normalization backward call described by the
// batch/hidden sizes, input/weight dtypes, norm type, zero-centered-gamma flag
// and SM margin. Declaration only; the definition is not part of this header.
// NOTE(review): the name implies the tuple describes workspace sizes, but the
// tuple layout is not visible here -- confirm against the implementation.
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
                                              DType w_dtype, NVTE_Norm_Type norm_type,
                                              bool zero_centered_gamma, int sm_margin);
57

58
// Quantization
59
// Declares the XLA FFI handler symbol `DBiasQuantizeHandler` (quantization
// section). Declaration only: the handler body is not part of this header.
XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);
60

61
62
// Declares the XLA FFI handler symbol `DequantizeHandler`.
// Declaration only: the handler body is not part of this header.
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);

63
64
// Returns a pybind11 tuple for a dbias+quantize call described by the
// batch/hidden sizes and the input/output dtypes. Declaration only.
// NOTE(review): the name implies the tuple describes workspace sizes, but the
// tuple layout is not visible here -- confirm against the implementation.
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                               DType in_dtype, DType out_dtype);
65

66
// Declares the XLA FFI handler symbol `DActLuDBiasQuantizeHandler`.
// Declaration only: the handler body is not part of this header.
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler);
67

68
69
70
// Returns a pybind11 tuple for a dact+dbias+quantize call described by the
// batch/hidden sizes, input/output dtypes, scaling mode and the `is_2x` flag.
// Declaration only; the definition is not part of this header.
// NOTE(review): the meaning of `is_2x` and the tuple layout are not visible
// here -- confirm against the implementation.
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType out_dtype,
                                                   int scaling_mode, bool is_2x);
71

72
// Softmax
73
74
75
76
77
78
79
80
81
82
83
84
// XLA FFI handler symbols for the softmax section, declared as
// forward/backward pairs for three variants (plain, masked, upper-triangular
// masked -- variant names taken from the symbols). Declarations only: the
// handler bodies are not part of this header.

// Scaled softmax.
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler);

// Scaled masked softmax.
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler);

// Scaled upper-triangular masked softmax.
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler);

85
// Attention
86
// Declares the XLA FFI handler symbol `FusedAttnForwardHandler` (attention
// section). Declaration only: the handler body is not part of this header.
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
87

88
// Declares the XLA FFI handler symbol `FusedAttnBackwardHandler`.
// Declaration only: the handler body is not part of this header.
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);
89

90
91
92
93
94
// Returns an NVTE_Fused_Attn_Backend value for the attention configuration
// given by the q/kv dtypes, QKV layout, bias and mask types, dropout
// probability, q/kv head counts, q/kv maximum sequence lengths, head dim and
// the left/right window sizes. Declaration only; the definition is not part
// of this header.
// NOTE(review): the selection criteria are not visible here -- confirm
// against the implementation before relying on a particular backend.
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_num_heads, size_t kv_num_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim, int64_t window_size_left,
                                            int64_t window_size_right);
97

98
// Returns a pybind11 tuple for a fused-attention forward call. The arguments
// give the input/bias batch sizes, q/kv maximum sequence lengths, head counts
// (attention heads, GQA groups, bias heads), head dim, scaling factor,
// dropout probability, bias/mask/QKV-layout enums, dtype, training flag,
// maximum segments per sequence and the left/right window sizes.
// Declaration only; the definition is not part of this header.
// NOTE(review): the name implies the tuple describes workspace sizes, but the
// tuple layout is not visible here -- confirm against the implementation.
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);
104
105

// Returns a pybind11 tuple for a fused-attention backward call. The arguments
// give the input/bias batch sizes, q/kv maximum sequence lengths, head counts
// (attention heads, GQA groups, bias heads), head dim, scaling factor,
// dropout probability, bias/mask/QKV-layout enums, dtype, training flag,
// `deterministic` flag, maximum segments per sequence and the left/right
// window sizes. Declaration only; the definition is not part of this header.
// NOTE(review): the name implies the tuple describes workspace sizes, and
// `deterministic` presumably requests a deterministic backward path; neither
// is visible here -- confirm against the implementation.
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
    int64_t window_size_right);
112

113
114
// Grouped GEMM
// Declares the XLA FFI handler symbol `GroupedGemmHandler`. Declaration only:
// the handler body is not part of this header.
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
115

116
117
118
119
120
// cuDNN helpers
// Declares the XLA FFI handler symbol `CudnnHandleInitHandler`. Declaration
// only; presumably initializes a cuDNN handle, as the name suggests -- confirm
// against the definition.
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);

// cuBLAS helpers
// Declares the XLA FFI handler symbol `CublasHandleInitHandler`. Declaration
// only; presumably initializes a cuBLAS handle, as the name suggests --
// confirm against the definition.
XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler);
121

122
123
124
125
}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_