/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

// extensions.h is the only source of the pybind11 declarations used in this
// file; keep it first so Python.h is seen before any standard header.
#include "extensions.h"

#include <type_traits>

namespace transformer_engine {
namespace jax {

/// Wraps a legacy (non-FFI) custom-call function in a PyCapsule named
/// "xla._CUSTOM_CALL_TARGET", the tag used for XLA custom-call targets.
///
/// @param fn  Pointer to the custom-call entry point; must point to a function.
/// @return    A capsule holding `fn` as an opaque `void *`.
template <typename T>
pybind11::capsule EncapsulateFunction(T *fn) {
  // Only functions are meaningful custom-call targets; reject data pointers at
  // compile time instead of silently wrapping them.
  static_assert(std::is_function_v<T>, "Encapsulated custom-call target must be a function");
  // Function-pointer -> void* is conditionally-supported in ISO C++; it is
  // relied on here because the capsule stores an untyped void pointer.
  return pybind11::capsule(reinterpret_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
}

/// Wraps an XLA FFI handler in a PyCapsule named "xla._CUSTOM_CALL_TARGET".
///
/// The static_assert rejects, at compile time, any function that cannot be
/// called as `XLA_FFI_Error *(XLA_FFI_CallFrame *)`, so a function with a
/// different signature cannot be registered as an FFI target by mistake.
///
/// @param fn  Pointer to the FFI handler.
/// @return    A capsule holding `fn` as an opaque `void *`.
template <typename T>
pybind11::capsule EncapsulateFFI(T *fn) {
  static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
                "Encapsulated function must be an XLA FFI handler");
  return pybind11::capsule(reinterpret_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
}

/// Builds the table of custom-call targets exported to Python.
///
/// Keys are target names. Values are either a single capsule (legacy custom
/// calls and FFI handlers without extra stages) or, for FFI handlers that need
/// a cuDNN handle, a dict whose "prepare" entry is CudnnHandleInitHandler and
/// whose "execute" entry is the op's own handler.
pybind11::dict Registrations() {
  pybind11::dict dict;

  // Legacy (non-FFI) custom calls.
  dict["te_transpose"] = EncapsulateFunction(Transpose);
  dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose);

  dict["te_act_lu"] = EncapsulateFunction(ActLu);
  dict["te_act_lu_fp8"] = EncapsulateFunction(ActLuFP8);
  dict["te_dact_lu"] = EncapsulateFunction(DActLu);
  dict["te_dbias_cast_transpose"] = EncapsulateFunction(DBiasCastTranspose);
  dict["te_dact_lu_dbias_cast_transpose"] = EncapsulateFunction(DActLuDBiasCastTranspose);
  dict["te_dgated_act_lu_cast_transpose"] = EncapsulateFunction(DGatedActLuCastTranspose);

  dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward);
  dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8);
  dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward);
  dict["te_rmsnorm_forward"] = EncapsulateFunction(RMSNormForward);
  dict["te_rmsnorm_forward_fp8"] = EncapsulateFunction(RMSNormForwardFP8);
  dict["te_rmsnorm_backward"] = EncapsulateFunction(RMSNormBackward);
  dict["te_quantize"] = EncapsulateFunction(Quantize);
  dict["te_dequantize"] = EncapsulateFunction(Dequantize);
  dict["te_scaled_softmax_forward"] = EncapsulateFunction(ScaledSoftmaxForward);
  dict["te_scaled_softmax_backward"] = EncapsulateFunction(ScaledSoftmaxBackward);
  dict["te_scaled_masked_softmax_forward"] = EncapsulateFunction(ScaledMaskedSoftmaxForward);
  dict["te_scaled_masked_softmax_backward"] = EncapsulateFunction(ScaledMaskedSoftmaxBackward);
  dict["te_scaled_upper_triang_masked_softmax_forward"] =
      EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward);
  dict["te_scaled_upper_triang_masked_softmax_backward"] =
      EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward);
  dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
  dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);

  // FFI handlers that need a cuDNN handle are registered as a
  // {"prepare", "execute"} dict with CudnnHandleInitHandler as the "prepare"
  // stage. One helper keeps every such entry built the same way.
  const auto with_cudnn_init = [](auto *execute) {
    return pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
                          pybind11::arg("execute") = EncapsulateFFI(execute));
  };

  // Transpose
  dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler);
  dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler);
  dict["te_dbias_cast_transpose_ffi"] = EncapsulateFFI(DBiasCastTransposeHandler);

  // Activation
  dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
  dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler);
  dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
  dict["te_dact_lu_dbias_cast_transpose_ffi"] = EncapsulateFFI(DActLuDBiasCastTransposeHandler);
  dict["te_dgated_act_lu_cast_transpose_ffi"] = EncapsulateFFI(DGatedActLuCastTransposeHandler);

  // Quantization
  dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
  dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler);

  // Softmax
  dict["te_scaled_softmax_forward_ffi"] = EncapsulateFFI(ScaledSoftmaxForwardHandler);
  dict["te_scaled_softmax_backward_ffi"] = EncapsulateFFI(ScaledSoftmaxBackwardHandler);
  dict["te_scaled_masked_softmax_forward_ffi"] = EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler);
  dict["te_scaled_masked_softmax_backward_ffi"] =
      EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler);
  dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] =
      EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardHandler);
  dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] =
      EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler);

  // Normalization (needs a cuDNN handle)
  dict["te_layernorm_forward_ffi"] = with_cudnn_init(LayerNormForwardHandler);
  dict["te_layernorm_forward_fp8_ffi"] = with_cudnn_init(LayerNormForwardFP8Handler);
  dict["te_layernorm_backward_ffi"] = with_cudnn_init(LayerNormBackwardHandler);
  dict["te_rmsnorm_forward_ffi"] = with_cudnn_init(RMSNormForwardHandler);
  dict["te_rmsnorm_forward_fp8_ffi"] = with_cudnn_init(RMSNormForwardFP8Handler);
  dict["te_rmsnorm_backward_ffi"] = with_cudnn_init(RMSNormBackwardHandler);

  // Attention (needs a cuDNN handle)
  dict["te_fused_attn_forward_ffi"] = with_cudnn_init(FusedAttnForwardHandler);
  dict["te_fused_attn_backward_ffi"] = with_cudnn_init(FusedAttnBackwardHandler);

  return dict;
}

// Python module `transformer_engine_jax`: exposes the custom-call registration
// table, descriptor packers, runtime/version queries, workspace-size queries,
// and the enums those functions accept or return.
PYBIND11_MODULE(transformer_engine_jax, m) {
  // Custom-call targets and descriptor packing.
  m.def("registrations", &Registrations);
  m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor, pybind11::arg(), pybind11::arg(),
        pybind11::arg(), pybind11::arg("act_num") = 0);
  m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor, pybind11::arg(),
        pybind11::arg(), pybind11::arg(), pybind11::arg(), pybind11::arg(),
        pybind11::arg("act_num") = 0);
  m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
  m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
  m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
  m.def("get_fused_attn_backend", &GetFusedAttnBackend);

  // Runtime / device queries.
  m.def("get_cuda_version", &GetCudaRuntimeVersion);
  m.def("get_cudnn_version", &GetCudnnRuntimeVersion);
  m.def("get_device_compute_capability", &GetDeviceComputeCapability);
  m.def("get_cublasLt_version", &cublasLtGetVersion);

  // Workspace-size queries.
  m.def("get_dact_dbias_ct_workspace_sizes", &GetDActDBiasCastTransposeWorkspaceSizes);
  m.def("get_dbias_ct_workspace_sizes", &GetDBiasCastTransposeWorkspaceSizes);
  m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
  m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
  m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
  m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
  m.def("nvte_get_qkv_format", &nvte_get_qkv_format);

  // Enums. module_local() keeps these bindings private to this extension so
  // they do not collide with same-named types registered by other modules.
  pybind11::enum_<DType>(m, "DType", pybind11::module_local())
      .value("kByte", DType::kByte)
      .value("kInt32", DType::kInt32)
      .value("kInt64", DType::kInt64)
      .value("kFloat32", DType::kFloat32)
      .value("kFloat16", DType::kFloat16)
      .value("kBFloat16", DType::kBFloat16)
      .value("kFloat8E4M3", DType::kFloat8E4M3)
      .value("kFloat8E5M2", DType::kFloat8E5M2);

  pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local())
      .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
      .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
      .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);

  pybind11::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type", pybind11::module_local())
      .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
      .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
      .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK)
      .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
      .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)
      .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK",
             NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);

  pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
      .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
      .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
      .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)
      .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD)
      .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD)
      .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD);

  pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local())
      .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD)
      .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD)
      .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD);

  // Activation values are additionally exported into the module namespace.
  pybind11::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", pybind11::module_local())
      .value("GELU", NVTE_Activation_Type::GELU)
      .value("GEGLU", NVTE_Activation_Type::GEGLU)
      .value("SILU", NVTE_Activation_Type::SILU)
      .value("SWIGLU", NVTE_Activation_Type::SWIGLU)
      .value("RELU", NVTE_Activation_Type::RELU)
      .value("REGLU", NVTE_Activation_Type::REGLU)
      .value("QGELU", NVTE_Activation_Type::QGELU)
      .value("QGEGLU", NVTE_Activation_Type::QGEGLU)
      .value("SRELU", NVTE_Activation_Type::SRELU)
      .value("SREGLU", NVTE_Activation_Type::SREGLU)
      .export_values();

  pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
      .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
      .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
      .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
      .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8);
}

}  // namespace jax
}  // namespace transformer_engine