Commit 8e0fd518 authored by wenjh
Browse files

Fix build problems when FP4 is not supported


Signed-off-by: wenjh <wenjh@sugon.com>
parent d86ee4c8
...@@ -66,7 +66,7 @@ enable_testing() ...@@ -66,7 +66,7 @@ enable_testing()
include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
if(NOT DEFINED TE_LIB_PATH) if(NOT DEFINED TE_LIB_PATH)
execute_process(COMMAND bash -c "python3 -c 'import transformer_engine as te; print(te.__file__)'" execute_process(COMMAND bash -c "python3 -c 'import torch; import transformer_engine as te; print(te.__file__)'"
OUTPUT_VARIABLE TE_LIB_FILE OUTPUT_VARIABLE TE_LIB_FILE
OUTPUT_STRIP_TRAILING_WHITESPACE) OUTPUT_STRIP_TRAILING_WHITESPACE)
get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY) get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY)
......
...@@ -71,8 +71,12 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const ...@@ -71,8 +71,12 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
// Remove the use_cudnn check here when it is supported by both backends. // Remove the use_cudnn check here when it is supported by both backends.
const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype; const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;
#if FP4_TYPE_SUPPORTED
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3> || if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3> ||
std::is_same_v<InputType, fp4e2m1>){ std::is_same_v<InputType, fp4e2m1>){
#else
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
#endif
compute_t g = static_cast<compute_t>(gamma); compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) { if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f); g += static_cast<compute_t>(1.f);
......
...@@ -62,8 +62,12 @@ const std::string &typeName(DType type) { ...@@ -62,8 +62,12 @@ const std::string &typeName(DType type) {
{DType::kBFloat16, "bfloat16"}, {DType::kBFloat16, "bfloat16"},
{DType::kFloat8E4M3, "float8e4m3"}, {DType::kFloat8E4M3, "float8e4m3"},
{DType::kFloat8E5M2, "float8e5m2"}, {DType::kFloat8E5M2, "float8e5m2"},
{DType::kFloat8E8M0, "float8e8m0"}, {DType::kFloat8E8M0, "float8e8m0"}
{DType::kFloat4E2M1, "float4e2m1"}}; #if FP4_TYPE_SUPPORTED
,
{DType::kFloat4E2M1, "float4e2m1"}
#endif
};
return name_map.at(type); return name_map.at(type);
} }
......
...@@ -99,7 +99,7 @@ struct BitsNumber { ...@@ -99,7 +99,7 @@ struct BitsNumber {
template <typename T> template <typename T>
struct TypeInfo { struct TypeInfo {
#if FP4_TYPE_SUPPORTED #if FP4_TYPE_SUPPORTED
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, int8, fp4e2m1>; using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp4e2m1, int8>;
#else #else
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, int8>; using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, int8>;
#endif #endif
......
...@@ -232,10 +232,12 @@ size_t get_buffer_size_bytes(const size_t elements_num, const DType buffer_dtype ...@@ -232,10 +232,12 @@ size_t get_buffer_size_bytes(const size_t elements_num, const DType buffer_dtype
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last, size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
const DType buffer_dtype) { const DType buffer_dtype) {
#if FP4_TYPE_SUPPORTED
if (buffer_dtype == DType::kFloat4E2M1) { if (buffer_dtype == DType::kFloat4E2M1) {
NVTE_CHECK(dim_last % 2 == 0, NVTE_CHECK(dim_last % 2 == 0,
"Last dimension of a tensor with FP4 type of data must be an even number!"); "Last dimension of a tensor with FP4 type of data must be an even number!");
} }
#endif
const size_t elements_num = dim_first * dim_last; const size_t elements_num = dim_first * dim_last;
return get_buffer_size_bytes(elements_num, buffer_dtype); return get_buffer_size_bytes(elements_num, buffer_dtype);
} }
......
...@@ -624,6 +624,7 @@ struct TypeInfo { ...@@ -624,6 +624,7 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \ NVTE_ERROR("Invalid type."); \
} }
#if FP4_TYPE_SUPPORTED
#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \ #define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
switch (dtype) { \ switch (dtype) { \
using namespace transformer_engine; \ using namespace transformer_engine; \
...@@ -649,6 +650,30 @@ struct TypeInfo { ...@@ -649,6 +650,30 @@ struct TypeInfo {
default: \ default: \
NVTE_ERROR("Invalid type."); \ NVTE_ERROR("Invalid type."); \
} }
#else
#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E5M2: \
case DType::kFloat8E4M3: { \
NVTE_ERROR("FP8 type not instantiated for input."); \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
#endif
#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \ #define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \
switch (dtype) { \ switch (dtype) { \
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <stddef.h> #include <stddef.h>
#define TE_FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
...@@ -32,7 +34,12 @@ enum NVTEDType { ...@@ -32,7 +34,12 @@ enum NVTEDType {
kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */ kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */
kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */ kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */
kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */ kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */
#if TE_FP4_TYPE_SUPPORTED
kNVTEFloat4E2M1 = 10, /*!< 4-bit float (E2M1) */ kNVTEFloat4E2M1 = 10, /*!< 4-bit float (E2M1) */
kNVTEInt8 = 11, /*!< 8-bit integer */
#else
kNVTEInt8 = 10, /*!< 8-bit integer */
#endif
kNVTENumTypes /*!< Number of supported types */ kNVTENumTypes /*!< Number of supported types */
}; };
...@@ -411,8 +418,12 @@ enum class DType { ...@@ -411,8 +418,12 @@ enum class DType {
kFloat8E4M3 = 7, kFloat8E4M3 = 7,
kFloat8E5M2 = 8, kFloat8E5M2 = 8,
kFloat8E8M0 = 9, kFloat8E8M0 = 9,
#if TE_FP4_TYPE_SUPPORTED
kFloat4E2M1 = 10, kFloat4E2M1 = 10,
kInt8 = 11, kInt8 = 11,
#else
kInt8 = 10,
#endif
kNumTypes kNumTypes
}; };
...@@ -439,7 +450,13 @@ inline bool is_fp8_dtype(const DType t) { ...@@ -439,7 +450,13 @@ inline bool is_fp8_dtype(const DType t) {
* Return true if TE datatype is FP4 * Return true if TE datatype is FP4
* \param[in] DType TE Datatype of interest * \param[in] DType TE Datatype of interest
*/ */
inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; } inline bool is_fp4_dtype(const DType t) {
#if TE_FP4_TYPE_SUPPORTED
return t == DType::kFloat4E2M1;
#else
return false;
#endif
}
/*! \struct TensorWrapper /*! \struct TensorWrapper
* \brief C++ wrapper for the NVTETensor class. * \brief C++ wrapper for the NVTETensor class.
......
...@@ -24,7 +24,9 @@ size_t typeToNumBits(const DType type) { ...@@ -24,7 +24,9 @@ size_t typeToNumBits(const DType type) {
} }
size_t typeToSize(const DType type) { size_t typeToSize(const DType type) {
#if FP4_TYPE_SUPPORTED
NVTE_CHECK(type != DType::kFloat4E2M1, "typeToSize() Does not support FP4 data type."); NVTE_CHECK(type != DType::kFloat4E2M1, "typeToSize() Does not support FP4 data type.");
#endif
return typeToNumBits(type) / 8; return typeToNumBits(type) / 8;
} }
...@@ -44,8 +46,10 @@ std::string to_string(const DType type) { ...@@ -44,8 +46,10 @@ std::string to_string(const DType type) {
return "Float8E5M2"; return "Float8E5M2";
case DType::kFloat8E8M0: case DType::kFloat8E8M0:
return "Float8E8M0"; return "Float8E8M0";
#if FP4_TYPE_SUPPORTED
case DType::kFloat4E2M1: case DType::kFloat4E2M1:
return "Float4E2M1"; return "Float4E2M1";
#endif
case DType::kInt16: case DType::kInt16:
return "Int16"; return "Int16";
case DType::kInt32: case DType::kInt32:
......
...@@ -318,8 +318,10 @@ inline size_t typeToNumBits(transformer_engine::DType t) { ...@@ -318,8 +318,10 @@ inline size_t typeToNumBits(transformer_engine::DType t) {
case transformer_engine::DType::kFloat8E5M2: case transformer_engine::DType::kFloat8E5M2:
case transformer_engine::DType::kInt8: case transformer_engine::DType::kInt8:
return 8; return 8;
#if FP4_TYPE_SUPPORTED
case transformer_engine::DType::kFloat4E2M1: case transformer_engine::DType::kFloat4E2M1:
return 4; return 4;
#endif
default: default:
NVTE_ERROR("Invalid type"); NVTE_ERROR("Invalid type");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment