pybind.cpp 9.75 KB
Newer Older
1
/*************************************************************************
2
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
4
5
6
 *
 * See LICENSE for license information.
 ************************************************************************/

7
#include "../extensions.h"
Phuong Nguyen's avatar
Phuong Nguyen committed
8
9
#include "cgemm_helper.h"
#include "common/util/cuda_runtime.h"
10
11
12
13

namespace transformer_engine {
namespace jax {

14
15
16
17
18
19
20
/// Wraps an XLA FFI handler function pointer in a PyCapsule.
///
/// The capsule is tagged with the name "xla._CUSTOM_CALL_TARGET". That is the capsule
/// name XLA checks for custom-call targets; it is presumably consumed on the Python side
/// when the targets returned by Registrations() are registered (confirm against caller).
///
/// @tparam T  Function type of the handler. It must be callable as
///            `XLA_FFI_Error *(XLA_FFI_CallFrame *)`; this is enforced at compile time.
/// @param fn  Pointer to the handler. Only the raw pointer is stored and no capsule
///            destructor is passed, so nothing is freed when the capsule is released.
/// @return    A capsule holding `fn`.
template <typename T>
pybind11::capsule EncapsulateFFI(T *fn) {
  // Reject anything that does not have the XLA FFI handler calling convention.
  constexpr bool kIsFfiHandler =
      std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>;
  static_assert(kIsFfiHandler, "Encapsulated function must be an XLA FFI handler");

  void *const raw_handler = reinterpret_cast<void *>(fn);
  return pybind11::capsule(raw_handler, "xla._CUSTOM_CALL_TARGET");
}

21
/// Builds the table of XLA FFI custom-call targets exported to Python.
///
/// Each key is a custom-call target name. Each value is one of:
///   * a single capsule (from EncapsulateFFI) wrapping the handler, or
///   * a dict mapping an FFI stage name ("prepare", "initialize", "execute") to the
///     capsule for that stage, for targets that provide several stage handlers.
///
/// Several targets share a "prepare" handler: CudnnHandleInitHandler (normalization and
/// fused attention) and CublasHandleInitHandler (grouped GEMM).
///
/// @return Python dict from target name to capsule or stage-name -> capsule dict.
pybind11::dict Registrations() {
  pybind11::dict dict;

  // Activation
  dict["te_act_lu_ffi"] =
      pybind11::dict(pybind11::arg("initialize") = EncapsulateFFI(ActLuInitializeHandler),
                     pybind11::arg("execute") = EncapsulateFFI(ActLuHandler));
  dict["te_dact_dbias_quantize_ffi"] = pybind11::dict(
      pybind11::arg("initialize") = EncapsulateFFI(DActLuDBiasQuantizeInitializeHandler),
      pybind11::arg("execute") = EncapsulateFFI(DActLuDBiasQuantizeHandler));

  // Quantization (single-stage targets)
  dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler);
  dict["te_grouped_quantize_ffi"] = EncapsulateFFI(GroupedQuantizeHandler);
  dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler);

  // Softmax (single-stage targets)
  dict["te_scaled_softmax_forward_ffi"] = EncapsulateFFI(ScaledSoftmaxForwardHandler);
  dict["te_scaled_softmax_backward_ffi"] = EncapsulateFFI(ScaledSoftmaxBackwardHandler);
  dict["te_scaled_masked_softmax_forward_ffi"] = EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler);
  dict["te_scaled_masked_softmax_backward_ffi"] =
      EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler);
  dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] =
      EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardHandler);
  dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] =
      EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler);

  // Normalization: "prepare" uses the shared cuDNN handle-init handler.
  dict["te_norm_forward_ffi"] =
      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
                     pybind11::arg("initialize") = EncapsulateFFI(NormForwardInitializeHandler),
                     pybind11::arg("execute") = EncapsulateFFI(NormForwardHandler));
  dict["te_norm_backward_ffi"] =
      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
                     pybind11::arg("initialize") = EncapsulateFFI(NormBackwardInitializeHandler),
                     pybind11::arg("execute") = EncapsulateFFI(NormBackwardHandler));

  // Attention: "prepare" uses the shared cuDNN handle-init handler.
  dict["te_fused_attn_forward_ffi"] =
      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
                     pybind11::arg("execute") = EncapsulateFFI(FusedAttnForwardHandler));
  dict["te_fused_attn_backward_ffi"] =
      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
                     pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler));

  // GEMM: "prepare" uses the collective-GEMM init handler.
  dict["te_gemm_ffi"] =
      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CollectiveGemmInitHandler),
                     pybind11::arg("execute") = EncapsulateFFI(GemmHandler));

  // Grouped GEMM: "prepare" uses the shared cuBLAS handle-init handler.
  dict["te_grouped_gemm_d2h_group_sizes_ffi"] =
      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
                     pybind11::arg("execute") = EncapsulateFFI(GroupedGemmD2HGroupSizesHandler));
  dict["te_grouped_gemm_ffi"] =
      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
                     pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler));

  // Amax
  dict["te_rht_amax_ffi"] = pybind11::dict(
      pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler),
      pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler));

  return dict;
}

// Python extension module `transformer_engine_jax`.
//
// Exposes the FFI custom-call target table (`registrations`), runtime/device query
// helpers, workspace-size helpers, collective-GEMM setup helpers, and the enum types
// used by the Python frontend.
//
// Every enum is bound with pybind11::module_local(). This keeps each binding private to
// this extension module, so it cannot conflict with another module that binds the same
// C++ type. Enums that call .export_values() also expose their members directly as
// attributes of the enum's parent scope (here, the module).
PYBIND11_MODULE(transformer_engine_jax, m) {
  // FFI custom-call targets.
  m.def("registrations", &Registrations);

  // Backend selection and runtime/device queries.
  m.def("get_fused_attn_backend", &GetFusedAttnBackend);
  m.def("get_cuda_version", &GetCudaRuntimeVersion);
  m.def("get_cudnn_version", &GetCudnnRuntimeVersion);
  m.def("get_device_compute_capability", &GetDeviceComputeCapability);
  m.def("get_num_compute_streams", &nvte_get_num_compute_streams);
  m.def("get_cublasLt_version", &cublasLtGetVersion);

  // Workspace-size queries.
  m.def("get_dact_dbias_quantize_workspace_sizes", &GetDActDBiasQuantizeWorkspaceSizes);
  m.def("get_dbias_quantize_workspace_sizes", &GetDBiasQuantizeWorkspaceSizes);
  m.def("get_norm_fwd_workspace_sizes", &GetNormForwardWorkspaceSizes);
  m.def("get_norm_bwd_workspace_sizes", &GetNormBackwardWorkspaceSizes);
  m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
  m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
  m.def("nvte_get_qkv_format", &nvte_get_qkv_format);

  // FP8 GEMM layout capability. The legacy Python name says "non_nt", but the wrapped
  // C API is nvte_is_non_tn_fp8_gemm_supported. Keep the legacy name for existing
  // callers and also export a name that matches the C API.
  m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported);
  m.def("is_non_tn_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported);

  // Collective GEMM.
  m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator);
  m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams);

  pybind11::enum_<DType>(m, "DType", pybind11::module_local())
      .value("kByte", DType::kByte)
      .value("kInt32", DType::kInt32)
      .value("kInt64", DType::kInt64)
      .value("kFloat32", DType::kFloat32)
      .value("kFloat16", DType::kFloat16)
      .value("kBFloat16", DType::kBFloat16)
      .value("kFloat8E4M3", DType::kFloat8E4M3)
      .value("kFloat8E5M2", DType::kFloat8E5M2)
      .value("kFloat8E8M0", DType::kFloat8E8M0)
      .value("kFloat4E2M1", DType::kFloat4E2M1);

  pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local())
      .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
      .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
      .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);

  pybind11::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type", pybind11::module_local())
      .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
      .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
      .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK)
      .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
      .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)
      .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK",
             NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);

  // Only the QKV layouts listed here are exposed to Python.
  pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
      .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
      .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
      .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)
      .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD)
      .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD)
      .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD);

  pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local())
      .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD)
      .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD)
      .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD);

  pybind11::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", pybind11::module_local())
      .value("GELU", NVTE_Activation_Type::GELU)
      .value("GEGLU", NVTE_Activation_Type::GEGLU)
      .value("SILU", NVTE_Activation_Type::SILU)
      .value("SWIGLU", NVTE_Activation_Type::SWIGLU)
      .value("RELU", NVTE_Activation_Type::RELU)
      .value("REGLU", NVTE_Activation_Type::REGLU)
      .value("QGELU", NVTE_Activation_Type::QGELU)
      .value("QGEGLU", NVTE_Activation_Type::QGEGLU)
      .value("SRELU", NVTE_Activation_Type::SRELU)
      .value("SREGLU", NVTE_Activation_Type::SREGLU)
      .value("CLAMPED_SWIGLU", NVTE_Activation_Type::CLAMPED_SWIGLU)
      .export_values();

  pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
      .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
      .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
      .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
      .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8);

  pybind11::enum_<NVTE_Norm_Type>(m, "NVTE_Norm_Type", pybind11::module_local())
      .value("LayerNorm", NVTE_Norm_Type::LayerNorm)
      .value("RMSNorm", NVTE_Norm_Type::RMSNorm)
      .export_values();

  pybind11::enum_<JAXX_Scaling_Mode>(m, "JAXX_Scaling_Mode", pybind11::module_local())
      .value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING)
      .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
      .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING)
      .value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING)
      .value("NVFP4_1D_SCALING", JAXX_Scaling_Mode::NVFP4_1D_SCALING)
      .value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING)
      .export_values();

  pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
                                                           pybind11::module_local())
      .value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE)
      .value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE)
      .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
      .export_values();

  pybind11::enum_<JAXX_Collective_Op>(m, "JAXX_Collective_Op", pybind11::module_local())
      .value("NONE", JAXX_Collective_Op::NONE)
      .value("ALL_GATHER", JAXX_Collective_Op::ALL_GATHER)
      .value("REDUCE_SCATTER", JAXX_Collective_Op::REDUCE_SCATTER)
      .export_values();
}

}  // namespace jax
}  // namespace transformer_engine