/************************************************************************* * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include "common/util/logging.h" #include "paddle/extension.h" #include "paddle/phi/backends/all_context.h" namespace transformer_engine { namespace paddle_ext { // Paddle Tensor Utils template inline const void *GetDataPtr(const paddle::Tensor &x, int64_t index) { if (index < 0 || index >= x.numel()) { NVTE_ERROR("Index out of bound"); } return reinterpret_cast(x.data() + static_cast(index)); } template inline void *GetDataPtr(paddle::Tensor &x, int64_t index) { // NOLINT if (index < 0 || index >= x.numel()) { NVTE_ERROR("Index out of bound"); } return reinterpret_cast(x.data() + static_cast(index)); } template inline const void *GetOptionalDataPtr(const paddle::optional &x, int64_t index) { return x ? GetDataPtr(*x, index) : nullptr; } template inline void *GetOptionalDataPtr(paddle::optional &x, int64_t index) { // NOLINT return x ? GetDataPtr(*x, index) : nullptr; } inline const void *GetOptionalDataPtr(const paddle::optional &x) { return x ? x->data() : nullptr; } inline void *GetOptionalDataPtr(paddle::optional &x) { // NOLINT return x ? x->data() : nullptr; } inline std::vector GetShapeArray(const paddle::Tensor &x) { std::vector shapes; for (auto dim : x.shape()) { shapes.push_back(static_cast(dim)); } return shapes; } inline std::vector GetShapeArray(const paddle::optional &x) { if (x) return GetShapeArray(x.get()); return {0}; } paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place, bool init_to_zeros = 0); // DType Utils inline paddle::DataType Nvte2PaddleDType(DType t) { switch (t) { case DType::kInt32: case DType::kFloat32: return paddle::DataType::FLOAT32; case DType::kFloat16: return paddle::DataType::FLOAT16; case DType::kBFloat16: return paddle::DataType::BFLOAT16; case DType::kByte: case DType::kFloat8E4M3: case DType::kFloat8E5M2: return paddle::DataType::UINT8; default: NVTE_ERROR("Invalid type"); } } inline DType Paddle2NvteDType(paddle::DataType t) { switch (t) { case paddle::DataType::FLOAT16: return DType::kFloat16; case paddle::DataType::FLOAT32: return DType::kFloat32; case paddle::DataType::BFLOAT16: return DType::kBFloat16; case paddle::DataType::BOOL: return DType::kByte; case paddle::DataType::UINT8: return DType::kByte; case paddle::DataType::INT32: return DType::kInt32; case paddle::DataType::INT64: return DType::kInt64; default: NVTE_ERROR("Invalid type"); } } inline DType Int2NvteDType(int64_t dtype) { if (dtype >= 0 && dtype < static_cast(DType::kNumTypes)) { return static_cast(dtype); } else { NVTE_ERROR("Type not supported."); } } // get the fused attention backend inline NVTE_Fused_Attn_Backend get_fused_attn_backend( const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, -1, -1); return fused_attention_backend; } // CUDA Utils class cudaDevicePropertiesManager { public: static cudaDevicePropertiesManager &Instance() { static thread_local cudaDevicePropertiesManager instance; return instance; } int GetMultiProcessorCount() { if (!prop_queried_) { int device_id; NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); cudaGetDeviceProperties(&prop_, device_id); prop_queried_ = true; } return prop_.multiProcessorCount; } int GetMajor() { if (!prop_queried_) { int device_id; NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); cudaGetDeviceProperties(&prop_, device_id); prop_queried_ = true; } return prop_.major; } private: bool prop_queried_ = false; cudaDeviceProp prop_; }; // NVTE Tensor Utils TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector &shape, const DType type); TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type); TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector &shape, const DType type, void *amax_ptr, void *scale_ptr, void *scale_inv_ptr); TensorWrapper MakeNvteTensor(paddle::Tensor &tensor); // NOLINT TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor); NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout); } // namespace paddle_ext } // namespace transformer_engine