/************************************************************************* * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_ #define TRANSFORMER_ENGINE_COMMON_COMMON_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include namespace transformer_engine { struct SimpleTensor { void *dptr; std::vector shape; DType dtype; SimpleTensor(void *dptr, const std::vector &shape, DType dtype) : dptr(dptr), shape(shape), dtype(dtype) {} SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {} }; struct Tensor { SimpleTensor data; SimpleTensor amax; SimpleTensor scale; SimpleTensor scale_inv; Tensor() : data(), amax(nullptr, {1}, DType::kFloat32), scale(nullptr, {1}, DType::kFloat32), scale_inv(nullptr, {1}, DType::kFloat32) {} }; template constexpr T DIVUP(const T &x, const T &y) { return (((x) + ((y)-1)) / (y)); } using byte = uint8_t; using int32 = int32_t; using fp32 = float; using fp16 = half; using bf16 = nv_bfloat16; using fp8e4m3 = __nv_fp8_e4m3; using fp8e5m2 = __nv_fp8_e5m2; template struct TypeInfo{ using types = std::tuple; template struct Helper { constexpr static DType getType() { constexpr int i = static_cast(current); if (std::is_same::type>::value) { return current; } else { return Helper(i + 1)>::getType(); } } }; template struct Helper { constexpr static DType getType() { return DType::kNumTypes; } }; template constexpr static DType getType() { return Helper::getType(); } constexpr static DType dtype = getType(); constexpr static size_t size = sizeof(T); }; #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ case DType::kByte: \ { \ using type = float; \ {__VA_ARGS__} \ } \ break; \ case DType::kInt32: \ { \ using type = float; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat32: \ { \ using type = float; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat16: \ { \ using type = fp16; \ {__VA_ARGS__} \ } \ break; \ case DType::kBFloat16: \ { \ using type = bf16; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat8E4M3: \ { \ using type = fp8e4m3; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat8E5M2: \ { \ using type = fp8e5m2; \ {__VA_ARGS__} \ } \ break; \ default: \ NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ case DType::kFloat32: \ { \ using type = float; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat16: \ { \ using type = fp16; \ {__VA_ARGS__} \ } \ break; \ case DType::kBFloat16: \ { \ using type = bf16; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat8E5M2: \ { \ using type = fp8e5m2; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat8E4M3: \ { \ using type = fp8e4m3; \ {__VA_ARGS__} \ } \ break; \ default: \ NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ case DType::kFloat8E5M2: \ { \ using type = fp8e5m2; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat8E4M3: \ { \ using type = fp8e4m3; \ {__VA_ARGS__} \ } \ break; \ default: \ NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ case DType::kFloat32: \ { \ using type = float; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat16: \ { \ using type = fp16; \ {__VA_ARGS__} \ } \ break; \ case DType::kBFloat16: \ { \ using type = bf16; \ {__VA_ARGS__} \ } \ break; \ case DType::kFloat8E5M2: \ case DType::kFloat8E4M3: \ { \ NVTE_ERROR("FP8 type not instantiated for input."); \ } \ break; \ default: \ NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \ switch (dtype) \ { \ using namespace transformer_engine; \ case DType::kFloat16: \ { \ using type = fp16; \ __VA_ARGS__; \ break; \ } \ case DType::kBFloat16: \ { \ using type = bf16; \ __VA_ARGS__; \ break; \ } \ default: \ NVTE_ERROR("Invalid type for 16 bit."); \ } //////////////////////////////////////////////////////////////////////////////////////////////////// inline size_t product(const std::vector &shape) { size_t ret = 1; for (const auto &elem : shape) { ret *= elem; } return ret; } inline int log2_ceil(int value) { int log2_value = 0; while ((1 << log2_value) < value) ++log2_value; return log2_value; } template struct is_fp8 : std::false_type {}; template <> struct is_fp8 : std::true_type {}; template <> struct is_fp8 : std::true_type {}; size_t typeToSize(const DType type); void CheckInputTensor(const Tensor &t, const std::string &name); void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false); bool is_fp8_dtype(const DType t); } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_