/************************************************************************* * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_ #define TRANSFORMER_ENGINE_COMMON_COMMON_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include namespace transformer_engine { struct Tensor { void* dptr; std::vector shape; DType dtype; Tensor() : dptr(nullptr), shape(), dtype(DType::kFloat32) {} }; template constexpr T DIVUP(const T &x, const T &y) { return (((x) + ((y)-1)) / (y)); } using byte = uint8_t; using int32 = int32_t; using fp32 = float; using fp16 = half; using bf16 = nv_bfloat16; using fp8e4m3 = __nv_fp8_e4m3; using fp8e5m2 = __nv_fp8_e5m2; template struct TypeInfo{ using types = std::tuple; template struct Helper { constexpr static DType getType() { constexpr int i = static_cast(current); if (std::is_same::type>::value) { return current; } else { return Helper(i + 1)>::getType(); } } }; template struct Helper { constexpr static DType getType() { return DType::kNumTypes; } }; template constexpr static DType getType() { return Helper::getType(); } constexpr static DType dtype = getType(); constexpr static size_t size = sizeof(T); }; #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ case DType::kByte: \ { \ using type = float; \ {__VA_ARGS__} \ } \ break; \ case DType::kInt32: \ { \ using type = float; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat32: \ { \ using type = float; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat16: \ { \ using type = fp16; \ {__VA_ARGS__} \ } \ break; \ case DType::kBFloat16: \ { \ using type = bf16; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat8E4M3: \ { \ using type = fp8e4m3; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat8E5M2: \ { \ using type = fp8e5m2; \ {__VA_ARGS__} \ } \ break; \ default: \ NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ case DType::kFloat32: \ { \ using type = float; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat16: \ { \ using type = fp16; \ {__VA_ARGS__} \ } \ break; \ case DType::kBFloat16: \ { \ using type = bf16; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat8E5M2: \ { \ using type = fp8e5m2; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat8E4M3: \ { \ using type = fp8e4m3; \ {__VA_ARGS__} \ } \ break; \ default: \ NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ case DType::kFloat8E5M2: \ { \ using type = fp8e5m2; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat8E4M3: \ { \ using type = fp8e4m3; \ {__VA_ARGS__} \ } \ break; \ default: \ NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ case DType::kFloat32: \ { \ using type = float; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat16: \ { \ using type = fp16; \ {__VA_ARGS__} \ } \ break; \ case DType::kBFloat16: \ { \ using type = bf16; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat8E5M2: \ case DType::kFloat8E4M3: \ { \ NVTE_ERROR("FP8 type not instantiated for input."); \ } \ break; \ default: \ NVTE_ERROR("Invalid type."); \ } template struct TypeId{}; template<> struct TypeId{ constexpr static uint32_t Value = 0; }; template<> struct TypeId{ constexpr static uint32_t Value = 1; }; template<> struct TypeId{ constexpr static uint32_t Value = 2; }; template<> struct TypeId{ constexpr static uint32_t Value = 3; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Type2Key{ constexpr static uint32_t Value = TypeId::Value << S; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct WeightType2Key : public Type2Key{}; template struct InputType2Key : public Type2Key{}; template struct OutputType2Key : public Type2Key{}; template struct ComputeType2Key : public Type2Key{}; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Types2Key{ constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | OutputType2Key::Value | ComputeType2Key::Value; constexpr static inline uint64_t get(const uint64_t hidden_size){ constexpr uint64_t type_key = Value; return (type_key << 32) | hidden_size; } }; inline size_t product(const std::vector &shape) { size_t ret = 1; for (const auto &elem : shape) { ret *= elem; } return ret; } template struct is_fp8 : std::false_type {}; template <> struct is_fp8 : std::true_type {}; template <> struct is_fp8 : std::true_type {}; size_t typeToSize(const DType type); } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_