Commit 8e0fd518 authored by wenjh

Fix build problems when FP4 is not supported


Signed-off-by: wenjh <wenjh@sugon.com>
parent d86ee4c8
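The pattern throughout this commit: fence every FP4 reference behind a compile-time guard so the tree still builds on toolkits that lack the type. A minimal sketch of that pattern, assuming (as the header change below does) that CUDA 12.8+ is what provides the 4-bit E2M1 type; process() and the cuda_fp4.h include are illustrative, not part of the commit:

#include <cuda.h>        // defines CUDA_VERSION
#include <type_traits>

#define TE_FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)  // CUDA 12.8+

#if TE_FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>            // assumed home of the 4-bit E2M1 type
using fp4e2m1 = __nv_fp4_e2m1;   // alias matching the names in the diff
#endif

template <typename InputType>
void process() {
#if TE_FP4_TYPE_SUPPORTED
  if constexpr (std::is_same_v<InputType, fp4e2m1>) {
    // FP4-only path, compiled only when the type exists
    return;
  }
#endif
  // common path for every other type
}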
@@ -66,7 +66,7 @@ enable_testing()
include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
if(NOT DEFINED TE_LIB_PATH)
execute_process(COMMAND bash -c "python3 -c 'import transformer_engine as te; print(te.__file__)'"
execute_process(COMMAND bash -c "python3 -c 'import torch; import transformer_engine as te; print(te.__file__)'"
OUTPUT_VARIABLE TE_LIB_FILE
OUTPUT_STRIP_TRAILING_WHITESPACE)
get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY)
......
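The extra import of torch in the probe command is presumably a workaround for transformer_engine failing to import on its own (its shared libraries need torch loaded first): when the probe crashed, TE_LIB_FILE came back empty, TE_LIB_PATH was never set, and the build broke.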
@@ -71,8 +71,12 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
// Remove the use_cudnn check here when it is supported by both backends.
const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;
#if FP4_TYPE_SUPPORTED
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3> ||
std::is_same_v<InputType, fp4e2m1>){
#else
if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
#endif
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
......
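The duplicated condition under #if/#else is necessary even though if constexpr discards the untaken branch: every name in the condition must still be declared, and fp4e2m1 does not exist at all in a non-FP4 build. A contrived sketch (is_low_precision is hypothetical):

template <typename T>
constexpr bool is_low_precision() {
  // return std::is_same_v<T, fp4e2m1>;  // hard error when fp4e2m1 is
  //                                     // undeclared, even in a dead branch
#if FP4_TYPE_SUPPORTED
  return std::is_same_v<T, fp8e4m3> || std::is_same_v<T, fp4e2m1>;
#else
  return std::is_same_v<T, fp8e4m3>;
#endif
}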
@@ -62,8 +62,12 @@ const std::string &typeName(DType type) {
{DType::kBFloat16, "bfloat16"},
{DType::kFloat8E4M3, "float8e4m3"},
{DType::kFloat8E5M2, "float8e5m2"},
{DType::kFloat8E8M0, "float8e8m0"},
{DType::kFloat4E2M1, "float4e2m1"}};
{DType::kFloat8E8M0, "float8e8m0"}
#if FP4_TYPE_SUPPORTED
,
{DType::kFloat4E2M1, "float4e2m1"}
#endif
};
return name_map.at(type);
}
......
@@ -99,7 +99,7 @@ struct BitsNumber {
template <typename T>
struct TypeInfo {
#if FP4_TYPE_SUPPORTED
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, int8, fp4e2m1>;
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp4e2m1, int8>;
#else
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, int8>;
#endif
......
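Note the FP4 branch also reorders the tuple tail (fp4e2m1 before int8, rather than appending it), presumably so each tuple position keeps matching the numeric DType value declared further down (kFloat4E2M1 = 10 and kInt8 = 11 with FP4, kInt8 = 10 without). Under that assumption, a check like this would hold in both configurations:

// Hypothetical sanity check: tuple index == numeric enum value.
static_assert(std::is_same_v<
    std::tuple_element_t<static_cast<size_t>(DType::kInt8),
                         TypeInfo<float>::types>,
    int8>);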
@@ -232,10 +232,12 @@ size_t get_buffer_size_bytes(const size_t elements_num, const DType buffer_dtype
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
const DType buffer_dtype) {
#if FP4_TYPE_SUPPORTED
if (buffer_dtype == DType::kFloat4E2M1) {
NVTE_CHECK(dim_last % 2 == 0,
"Last dimension of a tensor with FP4 type of data must be an even number!");
}
#endif
const size_t elements_num = dim_first * dim_last;
return get_buffer_size_bytes(elements_num, buffer_dtype);
}
......
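The even-dimension requirement falls out of FP4 packing two 4-bit elements per byte: a row with an odd last dimension would not end on a byte boundary. The byte count the call ultimately computes is, in effect:

// elements_num * 4 bits, truncated to whole bytes; the even-dim check
// above guarantees the division is exact for FP4.
size_t bytes = elements_num * typeToNumBits(buffer_dtype) / 8;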
@@ -624,6 +624,7 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \
}
#if FP4_TYPE_SUPPORTED
#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
@@ -649,6 +650,30 @@ struct TypeInfo {
default: \
NVTE_ERROR("Invalid type."); \
}
#else
#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E5M2: \
case DType::kFloat8E4M3: { \
NVTE_ERROR("FP8 type not instantiated for input."); \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
#endif
#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \
switch (dtype) { \
......
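Both variants keep the macro's contract intact: map a runtime DType to a compile-time alias (here, type) and expand the caller's body once per case, so a template gets instantiated for each dtype. The non-FP4 variant simply turns the FP8 cases into runtime errors instead of instantiations. A hypothetical call site (cast_kernel is illustrative):

void dispatch(DType dtype, const Tensor &input) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, IType,
    cast_kernel<IType>(input);  // IType is float, fp16, bf16, ... per case
  );
}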
@@ -14,6 +14,8 @@
#include <cuda_runtime_api.h>
#include <stddef.h>
#define TE_FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
#ifdef __cplusplus
extern "C" {
#endif
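CUDA_VERSION encodes 1000 * major + 10 * minor, so 12080 corresponds to CUDA 12.8. The guard is pure preprocessor, which keeps it usable from this extern "C" header; note that if the included CUDA headers did not define CUDA_VERSION, the comparison would silently evaluate as 0 >= 12080 and FP4 would be compiled out. Downstream code can gate FP4-only includes the same way, e.g.:

#if TE_FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>  // assumed FP4 header, shipped from CUDA 12.8 onward
#endif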
@@ -32,7 +34,12 @@ enum NVTEDType {
kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */
kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */
kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */
#if TE_FP4_TYPE_SUPPORTED
kNVTEFloat4E2M1 = 10, /*!< 4-bit float (E2M1) */
kNVTEInt8 = 11, /*!< 8-bit integer */
#else
kNVTEInt8 = 10, /*!< 8-bit integer */
#endif
kNVTENumTypes /*!< Number of supported types */
};
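A consequence worth flagging: the numeric value of kNVTEInt8 now depends on the build configuration (11 with FP4, 10 without), so any two binaries exchanging these enum values must be built against the same CUDA version. In sketch form:

// The enumerator value shifts with the configuration.
static_assert(kNVTEInt8 == (TE_FP4_TYPE_SUPPORTED ? 11 : 10),
              "NVTEDType layout depends on FP4 support");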
@@ -411,8 +418,12 @@ enum class DType {
kFloat8E4M3 = 7,
kFloat8E5M2 = 8,
kFloat8E8M0 = 9,
#if TE_FP4_TYPE_SUPPORTED
kFloat4E2M1 = 10,
kInt8 = 11,
#else
kInt8 = 10,
#endif
kNumTypes
};
@@ -439,7 +450,13 @@ inline bool is_fp8_dtype(const DType t) {
* Return true if TE datatype is FP4
* \param[in] DType TE Datatype of interest
*/
inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; }
inline bool is_fp4_dtype(const DType t) {
#if TE_FP4_TYPE_SUPPORTED
return t == DType::kFloat4E2M1;
#else
return false;
#endif
}
/*! \struct TensorWrapper
* \brief C++ wrapper for the NVTETensor class.
......
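Folding the guard into is_fp4_dtype keeps call sites free of #if: when FP4 is compiled out, the function is constant false and the branch folds away. A hypothetical caller (last_dim is illustrative):

if (is_fp4_dtype(t)) {
  NVTE_CHECK(last_dim % 2 == 0, "FP4 rows must pack into whole bytes");
}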
@@ -24,7 +24,9 @@ size_t typeToNumBits(const DType type) {
}
size_t typeToSize(const DType type) {
#if FP4_TYPE_SUPPORTED
NVTE_CHECK(type != DType::kFloat4E2M1, "typeToSize() does not support the FP4 data type.");
#endif
return typeToNumBits(type) / 8;
}
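The check exists because typeToSize() computes typeToNumBits(type) / 8, which would silently report 0 bytes per element for the 4-bit FP4 type; rejecting FP4 here steers callers to bit-level sizing such as get_buffer_size_bytes() above.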
@@ -44,8 +46,10 @@ std::string to_string(const DType type) {
return "Float8E5M2";
case DType::kFloat8E8M0:
return "Float8E8M0";
#if FP4_TYPE_SUPPORTED
case DType::kFloat4E2M1:
return "Float4E2M1";
#endif
case DType::kInt16:
return "Int16";
case DType::kInt32:
......
@@ -318,8 +318,10 @@ inline size_t typeToNumBits(transformer_engine::DType t) {
case transformer_engine::DType::kFloat8E5M2:
case transformer_engine::DType::kInt8:
return 8;
#if FP4_TYPE_SUPPORTED
case transformer_engine::DType::kFloat4E2M1:
return 4;
#endif
default:
NVTE_ERROR("Invalid type");
}
......