OpenDAS / TransformerEngine · Commits

Commit b3dcfc28
Authored Dec 03, 2025 by wenjh
Parent: 1e3c6a25

Fix build error
Showing 10 changed files with 76 additions and 44 deletions (+76 -44):

  tests/cpp/test_common.h                                                           +2   -2
  transformer_engine/common/CMakeLists.txt                                         +20  -18
  transformer_engine/common/common.h                                                +3   -1
  transformer_engine/common/gemm/cublaslt_gemm.cu                                   +2   -3
  transformer_engine/common/include/transformer_engine/transformer_engine.h         +6  -20
  transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu   +11   -0
  transformer_engine/common/util/cast_kernels.cuh                                   +8   -0
  transformer_engine/common/util/logging.h                                          +2   -0
  transformer_engine/common/util/nvfp4_transpose.cuh                               +10   -0
  transformer_engine/pytorch/csrc/quantizer.cpp                                    +12   -0
tests/cpp/test_common.h
@@ -101,9 +101,9 @@ struct BitsNumber {
 template <typename T>
 struct TypeInfo {
 #if FP4_TYPE_SUPPORTED
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp4e2m1, int8>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8, fp8e8m0, fp4e2m1>;
 #else
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, int8>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8, fp8e8m0>;
 #endif
   template <typename U, DType current>
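The two edits reorder the tuple so int8 precedes the FP8/FP4 scaling types, matching the renumbered DType enum in transformer_engine.h below; TypeInfo::types is a positional map from enum value to C++ type, so the two must stay in lockstep. A minimal sketch of that mechanism (the enum values and aliases here are illustrative, not the library's actual definitions):

#include <cstdint>
#include <tuple>
#include <type_traits>

enum class DType { kByte = 0, kInt16 = 1, kInt8 = 2 };

// Position i of the tuple must hold the C++ type for enum value i.
using types = std::tuple<uint8_t, int16_t, int8_t>;

template <DType D>
using type_of = std::tuple_element_t<static_cast<std::size_t>(D), types>;

static_assert(std::is_same_v<type_of<DType::kInt8>, int8_t>,
              "tuple order out of sync with the DType enum");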
transformer_engine/common/CMakeLists.txt
@@ -110,8 +110,9 @@ set(CUTLASS_TOOLS_INCLUDE_DIR
 # Python
 find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)

-# NVIDIA MathDX include directory (from Python package install location)
-if(NOT DEFINED MATHDX_INCLUDE_DIR)
+if(USE_CUDA)
+  # NVIDIA MathDX include directory (from Python package install location)
+  if(NOT DEFINED MATHDX_INCLUDE_DIR)
     execute_process(
       COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx
       OUTPUT_VARIABLE _PIP_SHOW_MATHDX
@@ -127,9 +128,10 @@ if(NOT DEFINED MATHDX_INCLUDE_DIR)
     endif()
     set(MATHDX_LOCATION "${CMAKE_MATCH_1}")
     set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include")
-endif()
-if(NOT EXISTS "${MATHDX_INCLUDE_DIR}")
+  endif()
+  if(NOT EXISTS "${MATHDX_INCLUDE_DIR}")
     message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.")
-endif()
+  endif()
+endif()

 # Configure Transformer Engine library
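The effect is that the MathDX probe runs only for CUDA builds, so a ROCm configure no longer fails when the nvidia-mathdx pip package is absent. On a CUDA build the requirement is unchanged: the package (which ships the cuRANDDx headers included elsewhere in this commit) must be installed for the configuring Python, or the path passed by hand, e.g. cmake -DMATHDX_INCLUDE_DIR=/path/to/mathdx/include (illustrative invocation).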
transformer_engine/common/common.h
@@ -417,11 +417,13 @@ struct BitsNumber {
 template <typename T>
 struct TypeInfo {
 #if FP4_TYPE_SUPPORTED
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8, fp4e2m1>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8
+#if CUDA_VERSION >= 12080
+                           , fp8e8m0
+#endif
+                           , fp4e2m1>;
 #else
   using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8
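Note the nested guard: fp8e8m0 is a tuple member only when the CUDA toolkit is 12.8 or newer, so the tuple's size and the position of fp4e2m1 shift with the toolkit. A small sketch of that consequence (the types and the HAVE_E8M0 macro are stand-ins for the real guard):

#include <tuple>

struct fp16 {};
struct fp8e8m0 {};  // stand-in: only available on newer toolkits
struct fp4e2m1 {};

#define HAVE_E8M0 1  // stand-in for CUDA_VERSION >= 12080

using types = std::tuple<fp16
#if HAVE_E8M0
                         , fp8e8m0
#endif
                         , fp4e2m1>;

// Indices past the guard shift with the toolkit, so members should be
// looked up via the enum rather than hard-coded positions.
static_assert(std::tuple_size_v<types> == 2 + HAVE_E8M0);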
transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -1175,9 +1175,6 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
   const Tensor *inputCounter = convertNVTETensor(counter);
   Tensor *wspace = convertNVTETensor(workspace);

-  const void *alpha_ptr = GetScalarOne();
-  const void *beta_ptr = accumulate ? GetScalarOne() : GetScalarZero();
-
   NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
                  is_delayed_tensor_scaling(inputB->scaling_mode),
              "Atomic GEMM only supports delayed scaling.");
@@ -1230,6 +1227,8 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
                 stream);
   }
 #else
+  const void *alpha_ptr = GetScalarOne();
+  const void *beta_ptr = accumulate ? GetScalarOne() : GetScalarZero();
   cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu,
               (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
               grad, wspace->data.dptr, wspace->data.shape[0],
               alpha_ptr, beta_ptr, use_split_accumulator, math_sm_count, m_split, n_split,
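Moving alpha_ptr and beta_ptr into the #else branch confines them to the only code path that reads them; defining them unconditionally leaves unused variables in the other branch, which breaks builds that treat warnings as errors. The shape of the fix in isolation (all names here are hypothetical):

static const float kOne = 1.0f;
static const float kZero = 0.0f;

void run_fast_gemm();                                          // hypothetical
void run_reference_gemm(const void *alpha, const void *beta);  // hypothetical

void gemm_dispatch(bool accumulate) {
#ifdef USE_FAST_PATH  // hypothetical feature guard
  run_fast_gemm();    // scaling handled internally; locals here would be unused
#else
  const void *alpha_ptr = &kOne;
  const void *beta_ptr = accumulate ? &kOne : &kZero;
  run_reference_gemm(alpha_ptr, beta_ptr);
#endif
}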
transformer_engine/common/include/transformer_engine/transformer_engine.h
@@ -14,8 +14,6 @@
 #include <cuda_runtime_api.h>
 #include <stddef.h>

-#define TE_FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
-
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -33,13 +31,9 @@ enum NVTEDType {
   kNVTEBFloat16 = 6,   /*!< 16-bit bfloat (E8M7) */
   kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */
   kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */
-  kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */
-#if TE_FP4_TYPE_SUPPORTED
-  kNVTEFloat4E2M1 = 10, /*!< 4-bit float (E2M1) */
-  kNVTEInt8 = 11,       /*!< 8-bit integer */
-#else
-  kNVTEInt8 = 10, /*!< 8-bit integer */
-#endif
+  kNVTEInt8 = 9,        /*!< 8-bit integer */
+  kNVTEFloat8E8M0 = 10, /*!< 8-bit float (E8M0) */
+  kNVTEFloat4E2M1 = 11, /*!< 4-bit float (E2M1) */
   kNVTENumTypes /*!< Number of supported types */
 };
@@ -423,13 +417,9 @@ enum class DType {
   kBFloat16 = 6,
   kFloat8E4M3 = 7,
   kFloat8E5M2 = 8,
-  kFloat8E8M0 = 9,
-#if TE_FP4_TYPE_SUPPORTED
-  kFloat4E2M1 = 10,
-  kInt8 = 11,
-#else
-  kInt8 = 10,
-#endif
+  kInt8 = 9,
+  kFloat8E8M0 = 10,
+  kFloat4E2M1 = 11,
   kNumTypes
 };
@@ -457,11 +447,7 @@ inline bool is_fp8_dtype(const DType t) {
  * \param[in] DType TE Datatype of interest
  */
 inline bool is_fp4_dtype(const DType t) {
-#if TE_FP4_TYPE_SUPPORTED
   return t == DType::kFloat4E2M1;
-#else
-  return false;
-#endif
 }

 /*! \brief Check if TE datatype is high precision (FP32, FP16, BF16)
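This is the core of the fix: the enumerator values were previously conditional on TE_FP4_TYPE_SUPPORTED, so objects compiled against different CUDA toolkits could disagree on what value kNVTEInt8 holds. With unconditional numbering the ABI is the same everywhere, which a consumer can verify at compile time (the static_asserts below are illustrative, not part of the header):

#include <transformer_engine/transformer_engine.h>

static_assert(static_cast<int>(transformer_engine::DType::kInt8) == 9,
              "Int8 value is fixed regardless of FP4 support");
static_assert(static_cast<int>(transformer_engine::DType::kFloat8E8M0) == 10,
              "E8M0 value is fixed regardless of FP4 support");
static_assert(static_cast<int>(transformer_engine::DType::kFloat4E2M1) == 11,
              "E2M1 value is fixed regardless of FP4 support");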
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu
@@ -5,13 +5,21 @@
  ************************************************************************/

 #include <cuda.h>
+#ifndef __HIP_PLATFORM_AMD__
 #include <cudaTypedefs.h>
+#else
+#define CUDA_VERSION 0
+#endif
 #include <cuda_bf16.h>
 #include <cuda_runtime.h>

 #include <algorithm>
 #include <cfloat>
+#ifndef __HIP_PLATFORM_AMD__
 #include <cuda/barrier>
+#endif
 #include <utility>

 #include "common/common.h"
@@ -19,7 +27,10 @@
#include "common/transpose/cast_transpose.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#ifndef __HIP_PLATFORM_AMD__
#include "curanddx.hpp"
#endif
namespace
transformer_engine
{
...
...
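ROCm ships neither <cudaTypedefs.h>, <cuda/barrier>, nor MathDX's curanddx.hpp, so the new guards skip those includes entirely; defining CUDA_VERSION as 0 on the HIP side keeps every later "#if CUDA_VERSION >= ..." gate in the file well-defined and uniformly false on that platform.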
transformer_engine/common/util/cast_kernels.cuh
@@ -576,6 +576,7 @@ __device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float b
 #define DIRECT_SCALING_FACTORS_STORE 1

+#ifndef __HIP_PLATFORM_AMD__
 template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
           typename IType, typename OType, bool COLWISE_SCALING, size_t CHUNK_DIM_Y,
           size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK>
@@ -1065,6 +1066,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
   destroy_barriers<STAGES>(mbar, is_master_thread);
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 }
+#endif
 }  // namespace nvfp4_kernel

 constexpr size_t FP8_CHUNK_DIM_Y = 128;
@@ -1725,6 +1727,11 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
 // 2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8
 template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &)>
 void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) {
+#ifdef __HIP_PLATFORM_AMD__
+  assert(false);
+#else
   using namespace nvfp4_kernel;
   using namespace ptx;
   checkCuDriverContext(stream);
@@ -1853,6 +1860,7 @@ void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cud
           break;
       });  // NOLINT(*)
   );  // NOLINT(*)
+#endif
 }

 namespace detail {
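The pattern here keeps nvfp4_quantize compiling and linking on ROCm while fencing off its CUDA-only body; if the stub is ever reached it trips the assert. In miniature (a sketch, not the actual function):

#include <cassert>

// The symbol exists on every platform; only CUDA builds get a real body.
void nvfp4_quantize_sketch() {
#ifdef __HIP_PLATFORM_AMD__
  assert(false && "NVFP4 quantization requires CUDA");
#else
  // ... set up descriptors and launch the NVFP4 kernels here ...
#endif
}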
transformer_engine/common/util/logging.h
@@ -23,7 +23,9 @@
 #endif  // __HIP_PLATFORM_AMD__
 #include <nvrtc.h>

+#ifndef __HIP_PLATFORM_AMD__
 #include "nccl.h"
+#endif

 #ifdef NVTE_WITH_CUBLASMP
 #include <cublasmp.h>
transformer_engine/common/util/nvfp4_transpose.cuh
@@ -12,7 +12,13 @@
 #define TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_

 #include <cuda.h>
+#ifndef __HIP_PLATFORM_AMD__
 #include <cudaTypedefs.h>
+#else
+#define CUDA_VERSION 0
+#endif
 #include <cuda_runtime.h>

 #if CUDA_VERSION > 12080
@@ -23,7 +29,11 @@
#include "../common.h"
#include "../utils.cuh"
#ifndef __HIP_PLATFORM_AMD__
#include "curanddx.hpp"
#endif
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/transformer_engine.h"
...
...
transformer_engine/pytorch/csrc/quantizer.cpp
@@ -1486,10 +1486,14 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
       // We need:
       // 1. Rowwise amax = amax for input
       // 2. Columnwise amax = amax for RHT(input.t)
+#ifdef __HIP_PLATFORM_AMD__
+      NVTE_CHECK(false, "nvte_hadamard_transform_amax is not supported on this platform");
+#else
       NVTE_SCOPED_GIL_RELEASE({
         nvte_hadamard_transform_amax(input.data(), out.data(), 0,
                                      this->rht_matrix_random_sign_mask_t, stream);
       });
+#endif
     } else {
       // raise error since it's not supported yet
       NVTE_CHECK(false, "Pre-RHT amax is not supported yet");
@@ -1612,11 +1616,15 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
       rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(),
                                         std::vector<size_t>{cols, rows});
+#ifdef __HIP_PLATFORM_AMD__
+      NVTE_CHECK(false, "nvte_hadamard_transform is not supported on this platform");
+#else
       NVTE_SCOPED_GIL_RELEASE({
         // Perform the RHT(input.t), and write to rht_output_cpp.columnwise.
         nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0,
                                 this->rht_matrix_random_sign_mask_t, stream);
       });
+#endif
       // Quantize kernel will treat everything as rowwise input/output, which is
       // intended.
@@ -1628,10 +1636,14 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
       NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0,
                  "RHT matrix is not set");
       auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix);
+#ifdef __HIP_PLATFORM_AMD__
+      NVTE_CHECK(false,
+                 "nvte_hadamard_transform_cast_fusion_columnwise is not supported on this platform");
+#else
       NVTE_SCOPED_GIL_RELEASE({
         nvte_hadamard_transform_cast_fusion_columnwise(input.data(), out_transpose.data(),
                                                        rht_matrix_nvte.data(), quant_config, stream);
       });
+#endif
       }
     }
   } else {
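At the PyTorch-binding layer the same fencing surfaces as a runtime failure rather than an assert: NVTE_CHECK(false, ...) raises with a clear message whenever an NVFP4 RHT path is invoked on ROCm. A minimal stand-in for the check (MY_CHECK is illustrative; the library's actual NVTE_CHECK definition is not shown in this diff):

#include <stdexcept>
#include <string>

// Illustrative stand-in for NVTE_CHECK: throw when the condition fails.
#define MY_CHECK(cond, msg)                                                       \
  do {                                                                            \
    if (!(cond)) throw std::runtime_error(std::string("Check failed: ") + (msg)); \
  } while (0)

void hadamard_amax_entry(bool on_rocm) {
  if (on_rocm) {
    MY_CHECK(false, "nvte_hadamard_transform_amax is not supported on this platform");
  }
  // ... otherwise the CUDA implementation would run ...
}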