extensions.cu 3.08 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/*************************************************************************
 * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "common.h"

namespace transformer_engine {
namespace paddle_ext {

size_t get_cublasLt_version() { return cublasLtGetVersion(); }

PYBIND11_MODULE(transformer_engine_paddle, m) {
    // Misc
    m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
17
    m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
18
    m.def("get_nvte_qkv_layout", &get_nvte_qkv_layout, "Get qkv layout enum by the string");
19
20
21
22
23
24
25
26
27
    // Data structures
    py::enum_<DType>(m, "DType", py::module_local())
        .value("kByte", DType::kByte)
        .value("kInt32", DType::kInt32)
        .value("kFloat32", DType::kFloat32)
        .value("kFloat16", DType::kFloat16)
        .value("kBFloat16", DType::kBFloat16)
        .value("kFloat8E4M3", DType::kFloat8E4M3)
        .value("kFloat8E5M2", DType::kFloat8E5M2);
28
29
30
31
32
33
34
35
36
37
38
39

    py::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type")
        .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
        .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
        .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);

    py::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type")
        .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
        .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
        .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);

    py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
        .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD)
        .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D)
        .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD)
        .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D)
        .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD)
        .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
        .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D)
        .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
        .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D)
        .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)
        .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD)
        .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D)
        .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD)
        .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D)
        .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD);
55
56
57
58
59
60

    py::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", py::module_local())
        .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
        .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
        .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8)
        .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend);
61
62
63
}
}  // namespace paddle_ext
}  // namespace transformer_engine