OpenDAS / TransformerEngine · Commits
Commit 8e0fd518 authored Sep 02, 2025 by wenjh
Fix build problems when FP4 is not supported
Signed-off-by: wenjh <wenjh@sugon.com>
parent d86ee4c8
Showing 9 changed files with 63 additions and 5 deletions (+63, -5)
tests/cpp/CMakeLists.txt  (+1, -1)
tests/cpp/operator/test_normalization.h  (+4, -0)
tests/cpp/test_common.cu  (+6, -2)
tests/cpp/test_common.h  (+1, -1)
transformer_engine/common/common.cu  (+2, -0)
transformer_engine/common/common.h  (+25, -0)
transformer_engine/common/include/transformer_engine/transformer_engine.h  (+18, -1)
transformer_engine/common/transformer_engine.cpp  (+4, -0)
transformer_engine/pytorch/csrc/common.h  (+2, -0)
tests/cpp/CMakeLists.txt

@@ -66,7 +66,7 @@ enable_testing()
 include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
 if(NOT DEFINED TE_LIB_PATH)
-    execute_process(COMMAND bash -c "python3 -c 'import transformer_engine as te; print(te.__file__)'"
+    execute_process(COMMAND bash -c "python3 -c 'import torch; import transformer_engine as te; print(te.__file__)'"
                     OUTPUT_VARIABLE TE_LIB_FILE
                     OUTPUT_STRIP_TRAILING_WHITESPACE)
     get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY)
tests/cpp/operator/test_normalization.h

@@ -71,8 +71,12 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
   // Remove the use_cudnn check here when it is supported by both backends.
   const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;
+#if FP4_TYPE_SUPPORTED
   if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3> ||
                 std::is_same_v<InputType, fp4e2m1>) {
+#else
+  if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>) {
+#endif
     compute_t g = static_cast<compute_t>(gamma);
     if (zero_centered_gamma) {
       g += static_cast<compute_t>(1.f);
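Note on the guard: a discarded `if constexpr` branch is still parsed, so the comparison against fp4e2m1 cannot remain in a build where that type is never declared; the whole condition has to be selected by the preprocessor. A minimal, self-contained sketch of the pattern (the MY_FP4_SUPPORTED flag and the stand-in type are hypothetical, not the test's real names):

#include <type_traits>

#define MY_FP4_SUPPORTED 0   // pretend this build has no FP4 support

#if MY_FP4_SUPPORTED
struct fp4e2m1 {};           // the type only exists in FP4-capable builds
#endif

template <typename T>
constexpr int element_bits() {
  // A discarded `if constexpr` branch must still name only declared entities,
  // so the FP4 comparison has to be removed by the preprocessor, which is
  // exactly what test_normalization.h does with FP4_TYPE_SUPPORTED.
#if MY_FP4_SUPPORTED
  if constexpr (std::is_same_v<T, fp4e2m1>) return 4;
#endif
  return 8;
}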
tests/cpp/test_common.cu

@@ -62,8 +62,12 @@ const std::string &typeName(DType type) {
       {DType::kBFloat16, "bfloat16"},
       {DType::kFloat8E4M3, "float8e4m3"},
       {DType::kFloat8E5M2, "float8e5m2"},
-      {DType::kFloat8E8M0, "float8e8m0"},
-      {DType::kFloat4E2M1, "float4e2m1"}};
+      {DType::kFloat8E8M0, "float8e8m0"}
+#if FP4_TYPE_SUPPORTED
+      ,
+      {DType::kFloat4E2M1, "float4e2m1"}
+#endif
+  };
   return name_map.at(type);
 }
tests/cpp/test_common.h

@@ -99,7 +99,7 @@ struct BitsNumber {
 template <typename T>
 struct TypeInfo {
 #if FP4_TYPE_SUPPORTED
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, int8, fp4e2m1>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp4e2m1, int8>;
 #else
   using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, int8>;
 #endif
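Note on the reorder: the tuple is evidently kept in DType order, so moving fp4e2m1 ahead of int8 matches the enum values introduced in transformer_engine.h below (kFloat4E2M1 = 10, kInt8 = 11). A rough illustration of why position matters when a type is mapped back to its enum value by tuple index; this is a sketch, not the actual TypeInfo implementation:

#include <cstddef>
#include <tuple>
#include <type_traits>

// If the element at index k is the C++ type for DType value k, a type can be
// mapped to its enum value through its tuple position.
template <typename T, typename Tuple>
struct tuple_index;

template <typename T, typename... Rest>
struct tuple_index<T, std::tuple<T, Rest...>>
    : std::integral_constant<std::size_t, 0> {};

template <typename T, typename U, typename... Rest>
struct tuple_index<T, std::tuple<U, Rest...>>
    : std::integral_constant<std::size_t, 1 + tuple_index<T, std::tuple<Rest...>>::value> {};

// With types = std::tuple<..., fp8e8m0, fp4e2m1, int8>, fp4e2m1 resolves to
// index 10 and int8 to index 11, matching kFloat4E2M1 = 10 and kInt8 = 11.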
transformer_engine/common/common.cu

@@ -232,10 +232,12 @@ size_t get_buffer_size_bytes(const size_t elements_num, const DType buffer_dtype
 size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last, const DType buffer_dtype) {
+#if FP4_TYPE_SUPPORTED
   if (buffer_dtype == DType::kFloat4E2M1) {
     NVTE_CHECK(dim_last % 2 == 0,
                "Last dimension of a tensor with FP4 type of data must be an even number!");
   }
+#endif
   const size_t elements_num = dim_first * dim_last;
   return get_buffer_size_bytes(elements_num, buffer_dtype);
 }
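Note on the even-dimension check: FP4 stores two 4-bit values per byte, so the contiguous last dimension must fill whole bytes. A small sketch of the arithmetic the check protects (the helper name is illustrative, not library code):

#include <cstddef>

// Illustrative helper, not the library's code path: 4 bits per FP4 element.
size_t fp4_buffer_bytes(size_t dim_first, size_t dim_last) {
  // dim_last must be even; with an odd dim_last each row would end mid-byte
  // and dim_last * 4 / 8 would round down.
  return dim_first * dim_last * 4 / 8;
}
// Example: a 128 x 64 FP4 buffer takes 128 * 64 / 2 = 4096 bytes, while
// dim_last = 63 would give 31.5 bytes per row, which cannot be packed on
// byte boundaries.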
transformer_engine/common/common.h

@@ -624,6 +624,7 @@ struct TypeInfo {
       NVTE_ERROR("Invalid type."); \
   }

+#if FP4_TYPE_SUPPORTED
 #define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
   switch (dtype) { \
     using namespace transformer_engine; \

@@ -649,6 +650,30 @@ struct TypeInfo {
     default: \
       NVTE_ERROR("Invalid type."); \
   }
+#else
+#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
+  switch (dtype) { \
+    using namespace transformer_engine; \
+    case DType::kFloat32: { \
+      using type = float; \
+      { __VA_ARGS__ } \
+    } break; \
+    case DType::kFloat16: { \
+      using type = fp16; \
+      { __VA_ARGS__ } \
+    } break; \
+    case DType::kBFloat16: { \
+      using type = bf16; \
+      { __VA_ARGS__ } \
+    } break; \
+    case DType::kFloat8E5M2: \
+    case DType::kFloat8E4M3: { \
+      NVTE_ERROR("FP8 type not instantiated for input."); \
+    } break; \
+    default: \
+      NVTE_ERROR("Invalid type."); \
+  }
+#endif

 #define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \
   switch (dtype) { \
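Note on the fallback macro: non-FP4 builds still get a compilable TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT, but it only instantiates the FP32/FP16/BF16 cases and raises a runtime error for FP8 input. A self-contained sketch of the dispatch pattern this macro family implements (simplified stand-in enum, types, and body; not TE's definitions, though the enum values mirror the header):

#include <cstdio>

enum class DType { kFloat32 = 4, kFloat16 = 5, kBFloat16 = 6 };
struct fp16 { unsigned short bits; };
struct bf16 { unsigned short bits; };

// Expands to a switch over the runtime DType; the body is instantiated once
// per case with the macro parameter `type` bound to the matching C++ type.
#define MY_TYPE_SWITCH_INPUT(dtype, type, ...)                                 \
  switch (dtype) {                                                             \
    case DType::kFloat32:  { using type = float; { __VA_ARGS__ } } break;      \
    case DType::kFloat16:  { using type = fp16;  { __VA_ARGS__ } } break;      \
    case DType::kBFloat16: { using type = bf16;  { __VA_ARGS__ } } break;      \
    default: std::printf("Invalid type.\n");                                   \
  }

int main() {
  DType d = DType::kFloat32;
  MY_TYPE_SWITCH_INPUT(d, T, std::printf("element size = %zu bytes\n", sizeof(T));)
}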
transformer_engine/common/include/transformer_engine/transformer_engine.h

@@ -14,6 +14,8 @@
 #include <cuda_runtime_api.h>
 #include <stddef.h>

+#define TE_FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
+
 #ifdef __cplusplus
 extern "C" {
 #endif

@@ -32,7 +34,12 @@ enum NVTEDType {
   kNVTEFloat8E4M3 = 7,  /*!< 8-bit float (E4M3) */
   kNVTEFloat8E5M2 = 8,  /*!< 8-bit float (E5M2) */
   kNVTEFloat8E8M0 = 9,  /*!< 8-bit float (E8M0) */
+#if TE_FP4_TYPE_SUPPORTED
   kNVTEFloat4E2M1 = 10, /*!< 4-bit float (E2M1) */
   kNVTEInt8 = 11,       /*!< 8-bit integer */
+#else
+  kNVTEInt8 = 10,       /*!< 8-bit integer */
+#endif
   kNVTENumTypes         /*!< Number of supported types */
 };

@@ -411,8 +418,12 @@ enum class DType {
   kFloat8E4M3 = 7,
   kFloat8E5M2 = 8,
   kFloat8E8M0 = 9,
+#if TE_FP4_TYPE_SUPPORTED
   kFloat4E2M1 = 10,
   kInt8 = 11,
+#else
+  kInt8 = 10,
+#endif
   kNumTypes
 };

@@ -439,7 +450,13 @@ inline bool is_fp8_dtype(const DType t) {
  * Return true if TE datatype is FP4
  * \param[in] DType TE Datatype of interest
  */
-inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; }
+inline bool is_fp4_dtype(const DType t) {
+#if TE_FP4_TYPE_SUPPORTED
+  return t == DType::kFloat4E2M1;
+#else
+  return false;
+#endif
+}

 /*! \struct TensorWrapper
  *  \brief C++ wrapper for the NVTETensor class.
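Note on the header changes: gating the FP4 enumerators on TE_FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080, i.e. CUDA 12.8) keeps the header self-consistent on older toolkits, and the constant-false fallback of is_fp4_dtype() means call sites need no guards of their own. A hedged sketch of such a call site (the function name is illustrative, not TE code):

#include <transformer_engine/transformer_engine.h>

// The same source compiles whether or not the toolkit is new enough for FP4:
// in non-FP4 builds is_fp4_dtype() is the constant-false variant, so no
// preprocessor guard is needed at the call site.
bool needs_fp4_packing(transformer_engine::DType dtype) {
  return transformer_engine::is_fp4_dtype(dtype);
}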
transformer_engine/common/transformer_engine.cpp

@@ -24,7 +24,9 @@ size_t typeToNumBits(const DType type) {
 }

 size_t typeToSize(const DType type) {
+#if FP4_TYPE_SUPPORTED
   NVTE_CHECK(type != DType::kFloat4E2M1, "typeToSize() Does not support FP4 data type.");
+#endif
   return typeToNumBits(type) / 8;
 }

@@ -44,8 +46,10 @@ std::string to_string(const DType type) {
     case DType::kFloat8E5M2:
       return "Float8E5M2";
     case DType::kFloat8E8M0:
       return "Float8E8M0";
+#if FP4_TYPE_SUPPORTED
     case DType::kFloat4E2M1:
       return "Float4E2M1";
+#endif
     case DType::kInt16:
       return "Int16";
     case DType::kInt32:
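Note on the typeToSize() check: the function reports whole bytes per element, and for a 4-bit type the integer division would silently truncate to zero, so FP4 is rejected outright. Plain arithmetic, not library code:

// typeToSize() returns typeToNumBits(type) / 8, i.e. whole bytes per element.
static_assert(8 / 8 == 1, "an FP8 element is one byte");
static_assert(4 / 8 == 0, "integer division truncates FP4 to zero bytes, "
                          "hence the NVTE_CHECK rejecting kFloat4E2M1");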
transformer_engine/pytorch/csrc/common.h

@@ -318,8 +318,10 @@ inline size_t typeToNumBits(transformer_engine::DType t) {
     case transformer_engine::DType::kFloat8E5M2:
     case transformer_engine::DType::kInt8:
       return 8;
+#if FP4_TYPE_SUPPORTED
     case transformer_engine::DType::kFloat4E2M1:
       return 4;
+#endif
     default:
       NVTE_ERROR("Invalid type");
   }
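Note on reporting FP4 as 4 bits: per-element byte counts are meaningless for sub-byte types, so buffer sizes should be derived from the total bit count. A hedged sketch of that computation (the helper is illustrative, not part of the extension):

#include <cstddef>

// Compute a packed buffer size from the per-element bit width, e.g. the value
// returned by typeToNumBits(); rounds up to whole bytes.
inline size_t packed_bytes(size_t numel, size_t bits_per_element) {
  return (numel * bits_per_element + 7) / 8;
}
// e.g. 4096 FP4 elements -> packed_bytes(4096, 4) == 2048 bytes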