/************************************************************************* * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #include "common.h" namespace transformer_engine { namespace paddle_ext { size_t get_cublasLt_version() { return cublasLtGetVersion(); } PYBIND11_MODULE(transformer_engine_paddle, m) { // Misc m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); m.def("get_nvte_qkv_layout", &get_nvte_qkv_layout, "Get qkv layout enum by the string"); // Data structures py::enum_(m, "DType", py::module_local()) .value("kByte", DType::kByte) .value("kInt32", DType::kInt32) .value("kFloat32", DType::kFloat32) .value("kFloat16", DType::kFloat16) .value("kBFloat16", DType::kBFloat16) .value("kFloat8E4M3", DType::kFloat8E4M3) .value("kFloat8E5M2", DType::kFloat8E5M2); py::enum_(m, "NVTE_Bias_Type") .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); py::enum_(m, "NVTE_Mask_Type") .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK); py::enum_(m, "NVTE_QKV_Layout") .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); py::enum_(m, "NVTE_Fused_Attn_Backend", py::module_local()) .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); } } // namespace paddle_ext } // namespace transformer_engine