/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <cublasLt.h>

#include "common/include/transformer_engine/fused_attn.h"
#include "common/include/transformer_engine/transformer_engine.h"
#include "jax/csrc/modules.h"
#include "jax/csrc/utils.h"

namespace transformer_engine {
namespace jax {

// Wraps a raw function pointer in a pybind11 capsule tagged with the name XLA
// expects for custom-call target registration.
template <typename T>
pybind11::capsule EncapsulateFunction(T *fn) {
    void *opaque = reinterpret_cast<void *>(fn);
    return pybind11::capsule(opaque, "xla._CUSTOM_CALL_TARGET");
}

// Builds the table of XLA custom-call targets exposed to Python: maps each
// target name to an encapsulated C function pointer. Consumed on the Python
// side to register the kernels with XLA.
pybind11::dict Registrations() {
    pybind11::dict dict;
    // Transpose / cast kernels.
    dict["te_transpose"] = EncapsulateFunction(Transpose);
    dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose);
    // Gated-GELU activation kernels (forward and backward).
    dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu);
    dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8);
    dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
    dict["te_dgated_gelu_cast_transpose"] = EncapsulateFunction(DGatedGeluCastTranspose);
    // Normalization kernels.
    dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward);
    dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8);
    dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward);
    dict["te_rmsnorm_forward"] = EncapsulateFunction(RMSNormForward);
    dict["te_rmsnorm_forward_fp8"] = EncapsulateFunction(RMSNormForwardFP8);
    dict["te_rmsnorm_backward"] = EncapsulateFunction(RMSNormBackward);
    // FP8 quantization kernels.
    dict["te_quantize"] = EncapsulateFunction(Quantize);
    dict["te_dequantize"] = EncapsulateFunction(Dequantize);
    // Softmax kernels.
    dict["te_scaled_softmax_forward"] = EncapsulateFunction(ScaledSoftmaxForward);
    dict["te_scaled_softmax_backward"] = EncapsulateFunction(ScaledSoftmaxBackward);
    dict["te_scaled_masked_softmax_forward"] = EncapsulateFunction(ScaledMaskedSoftmaxForward);
    dict["te_scaled_masked_softmax_backward"] = EncapsulateFunction(ScaledMaskedSoftmaxBackward);
    dict["te_scaled_upper_triang_masked_softmax_forward"] =
        EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward);
    dict["te_scaled_upper_triang_masked_softmax_backward"] =
        EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward);
    // Fused attention kernels.
    dict["te_self_fused_attn_forward"] = EncapsulateFunction(SelfFusedAttnForward);
    dict["te_self_fused_attn_backward"] = EncapsulateFunction(SelfFusedAttnBackward);
    dict["te_cross_fused_attn_forward"] = EncapsulateFunction(CrossFusedAttnForward);
    dict["te_cross_fused_attn_backward"] = EncapsulateFunction(CrossFusedAttnBackward);
    return dict;
}

// Python module entry point: exposes the custom-call registry, descriptor
// packers, version/workspace queries, and the enums shared with the Python
// layer. All enums are module_local to avoid clashing with other bindings.
PYBIND11_MODULE(transformer_engine_jax, m) {
    // Custom-call target registry and descriptor packing helpers.
    m.def("registrations", &Registrations);
    m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor);
    m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
    m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
    m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor);
    // Backend / capability queries.
    m.def("get_fused_attn_backend", &GetFusedAttnBackend);
    m.def("get_cuda_version", &GetCudaRuntimeVersion);
    m.def("get_device_compute_capability", &GetDeviceComputeCapability);
    m.def("get_cublasLt_version", &cublasLtGetVersion);
    // Workspace-size queries used to pre-allocate scratch buffers on the
    // Python side before launching the corresponding kernels.
    m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes);
    m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
    m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes);
    m.def("get_self_fused_attn_bwd_workspace_sizes", &GetSelfFusedAttnBackwardWorkspaceSizes);
    m.def("get_cross_fused_attn_fwd_workspace_sizes", &GetCrossFusedAttnForwardWorkspaceSizes);
    m.def("get_cross_fused_attn_bwd_workspace_sizes", &GetCrossFusedAttnBackwardWorkspaceSizes);

    // Tensor element types understood by the Transformer Engine kernels.
    pybind11::enum_<DType>(m, "DType", pybind11::module_local())
        .value("kByte", DType::kByte)
        .value("kInt32", DType::kInt32)
        .value("kInt64", DType::kInt64)
        .value("kFloat32", DType::kFloat32)
        .value("kFloat16", DType::kFloat16)
        .value("kBFloat16", DType::kBFloat16)
        .value("kFloat8E4M3", DType::kFloat8E4M3)
        .value("kFloat8E5M2", DType::kFloat8E5M2);

    // Fused-attention bias variants.
    pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local())
        .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
        .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
        .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);

    // Fused-attention mask variants.
    pybind11::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type", pybind11::module_local())
        .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
        .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
        .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK)
        .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK);

    // QKV memory layouts supported by the fused-attention custom calls.
    pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
        .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
        .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD);

    // Available fused-attention backend implementations.
    pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
        .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
        .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
        .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
        .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8);
}

}  // namespace jax
}  // namespace transformer_engine