Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
962b976d
"vscode:/vscode.git/clone" did not exist on "65bed07e38a0d291eb37dad29ef0bbd4570e769f"
Commit
962b976d
authored
Oct 15, 2025
by
yuguo
Browse files
[DCU] fix release 2.8 compile issues
parent
688c7ab9
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
50 additions
and
80 deletions
+50
-80
tests/cpp/test_common.h
tests/cpp/test_common.h
+2
-2
transformer_engine/common/CMakeLists.txt
transformer_engine/common/CMakeLists.txt
+0
-2
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
...mer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
+3
-3
transformer_engine/common/common.h
transformer_engine/common/common.h
+3
-1
transformer_engine/common/gemm/cublaslt_gemm.cu
transformer_engine/common/gemm/cublaslt_gemm.cu
+7
-4
transformer_engine/common/hadamard_transform/hadamard_transform.cu
...er_engine/common/hadamard_transform/hadamard_transform.cu
+0
-2
transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu
...mmon/hadamard_transform/hadamard_transform_cast_fusion.cu
+0
-2
transformer_engine/common/include/transformer_engine/transformer_engine.h
...ne/common/include/transformer_engine/transformer_engine.h
+6
-20
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu
...mmon/transpose/quantize_transpose_vector_blockwise_fp4.cu
+6
-0
transformer_engine/common/util/cast_kernels.cuh
transformer_engine/common/util/cast_kernels.cuh
+6
-1
transformer_engine/common/util/nvfp4_transpose.cuh
transformer_engine/common/util/nvfp4_transpose.cuh
+4
-0
transformer_engine/common/util/pybind_helper.h
transformer_engine/common/util/pybind_helper.h
+1
-1
transformer_engine/pytorch/csrc/extensions/attention.cpp
transformer_engine/pytorch/csrc/extensions/attention.cpp
+0
-41
transformer_engine/pytorch/csrc/quantizer.cpp
transformer_engine/pytorch/csrc/quantizer.cpp
+12
-1
No files found.
tests/cpp/test_common.h
View file @
962b976d
...
...
@@ -101,9 +101,9 @@ struct BitsNumber {
template
<
typename
T
>
struct
TypeInfo
{
#if FP4_TYPE_SUPPORTED
using
types
=
std
::
tuple
<
byte
,
int16
,
int32
,
int64
,
fp32
,
fp16
,
bf16
,
fp8e4m3
,
fp8e5m2
,
fp8e8m0
,
fp4e2m1
,
int8
>
;
using
types
=
std
::
tuple
<
byte
,
int16
,
int32
,
int64
,
fp32
,
fp16
,
bf16
,
fp8e4m3
,
fp8e5m2
,
int8
,
fp8e8m0
,
fp4e2m1
>
;
#else
using
types
=
std
::
tuple
<
byte
,
int16
,
int32
,
int64
,
fp32
,
fp16
,
bf16
,
fp8e4m3
,
fp8e5m2
,
fp8e8m0
,
int8
>
;
using
types
=
std
::
tuple
<
byte
,
int16
,
int32
,
int64
,
fp32
,
fp16
,
bf16
,
fp8e4m3
,
fp8e5m2
,
int8
,
fp8e8m0
>
;
#endif
template
<
typename
U
,
DType
current
>
...
...
transformer_engine/common/CMakeLists.txt
View file @
962b976d
...
...
@@ -266,8 +266,6 @@ else()
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
recipe/nvfp4.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
...
...
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
View file @
962b976d
...
...
@@ -343,6 +343,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
allgather_handle
,
barrier_handle
,
num_splits
,
num_max_streams
,
comm_cga_size
,
gemm_priority
,
comm_priority
,
num_comm_sm
,
set_sm_margin
,
false
,
atomic_gemm
)
{
_ub_stream_nums
=
num_max_streams
;
initialize
(
buffer_shape
,
buffer_dtype
,
rs_overlap_first_gemm
);
}
...
...
@@ -352,7 +353,6 @@ void CommOverlapBase::initialize(const std::vector<size_t> &buffer_shape, DType
if
(
NVTE_BLAS_MULSTREAM
!=
nullptr
&&
NVTE_BLAS_MULSTREAM
[
0
]
==
'1'
){
_ub_force_blas_multistream
=
true
;
}
_ub_stream_nums
=
num_max_streams
;
_rs_overlap_first_gemm
=
rs_overlap_first_gemm
;
_rs_kernel_type
=
getenv
<
int
>
(
"NVTE_RS_STRIDED_ATOMIC"
,
0
);
NVTE_CHECK
(
_rs_kernel_type
>=
0
&&
_rs_kernel_type
<=
3
,
...
...
@@ -776,6 +776,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
allgather_handle
,
barrier_handle
,
tp_size
,
num_max_streams
,
comm_cga_size
,
gemm_priority
,
comm_priority
,
num_comm_sm
,
set_sm_margin
,
use_ce
,
atomic_gemm
)
{
_ub_stream_nums
=
num_max_streams
;
initialize
(
buffer_shape
,
buffer_dtype
,
comm_type
,
aggregate
);
}
...
...
@@ -785,7 +786,6 @@ void CommOverlapP2PBase::initialize(const std::vector<size_t> &buffer_shape, DTy
if
(
NVTE_BLAS_MULSTREAM
!=
nullptr
&&
NVTE_BLAS_MULSTREAM
[
0
]
==
'1'
){
_ub_force_blas_multistream
=
true
;
}
_ub_stream_nums
=
num_max_streams
;
_is_p2p
=
true
;
_is_reduce_scatter
=
comm_type
==
CommOverlapType
::
RS
;
_aggregate
=
aggregate
;
...
...
@@ -839,7 +839,7 @@ void CommOverlapP2PBase::initialize(const std::vector<size_t> &buffer_shape, DTy
static
cudaStream_t
send_streams
[
NVTE_COMM_OVERLAP_MAX_STREAMS
];
static
cudaStream_t
recv_stream
;
for
(
int
i
=
0
;
i
<
std
::
min
(
num_max
_streams
,
_tp_size
);
i
++
)
{
for
(
int
i
=
0
;
i
<
std
::
min
(
_ub
_stream
_num
s
,
_tp_size
);
i
++
)
{
if
(
send_streams
[
i
]
==
nullptr
)
{
NVTE_CHECK_CUDA
(
cudaStreamCreateWithPriority
(
&
send_streams
[
i
],
cudaStreamNonBlocking
,
_comm_priority
));
}
...
...
transformer_engine/common/common.h
View file @
962b976d
...
...
@@ -417,11 +417,13 @@ struct BitsNumber {
template
<
typename
T
>
struct
TypeInfo
{
#if FP4_TYPE_SUPPORTED
using
types
=
std
::
tuple
<
byte
,
int16
,
int32
,
int64
,
fp32
,
fp16
,
bf16
,
fp8e4m3
,
fp8e5m2
,
int8
,
fp4e2m1
using
types
=
std
::
tuple
<
byte
,
int16
,
int32
,
int64
,
fp32
,
fp16
,
bf16
,
fp8e4m3
,
fp8e5m2
,
int8
#if CUDA_VERSION >= 12080
,
fp8e8m0
#endif
,
fp4e2m1
>
;
#else
using
types
=
std
::
tuple
<
byte
,
int16
,
int32
,
int64
,
fp32
,
fp16
,
bf16
,
fp8e4m3
,
fp8e5m2
,
int8
...
...
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
962b976d
...
...
@@ -990,7 +990,7 @@ void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTET
const
bool
use_fp8
=
is_fp8_dtype
(
A_tensor
->
data
.
dtype
)
||
is_fp8_dtype
(
B_tensor
->
data
.
dtype
);
const
char
*
NVTE_INT8_SIM_FP8_TENSORWISE
=
std
::
getenv
(
"NVTE_INT8_SIM_FP8_TENSORWISE"
);
if
(
NVTE_INT8_SIM_FP8_TENSORWISE
!=
nullptr
&&
NVTE_INT8_SIM_FP8_TENSORWISE
[
0
]
==
'1'
&&
use_int8
&&
use_split_accumulator
)
nvte_use_hipblaslt
=
1
;
if
(
NVTE_INT8_SIM_FP8_TENSORWISE
!=
nullptr
&&
NVTE_INT8_SIM_FP8_TENSORWISE
[
0
]
==
'1'
&&
use_int8
&&
config_
.
use_split_accumulator
)
nvte_use_hipblaslt
=
1
;
if
((
epilogue_bias_tensor
->
data
.
dptr
!=
nullptr
)
||
(
epilogue_aux_tensor
->
data
.
dptr
!=
nullptr
)
||
(
use_fp8
)
||
(
NVTE_FORCE_ROCM_GEMM
!=
nullptr
&&
NVTE_FORCE_ROCM_GEMM
[
0
]
==
'1'
)
||
(
nvte_use_hipblaslt
)
||
(
nvte_use_rocblas
))
{
cublas_gemm
(
A_tensor
,
B_tensor
,
D_tensor
,
epilogue_bias_tensor
,
epilogue_aux_tensor
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
transa
,
transb
,
with_grad_epilogue
,
workspace_ptr
,
workspace_size
,
accumulate
,
config_
.
use_split_accumulator
,
config_
.
sm_count
,
0
,
0
,
...
...
@@ -1015,11 +1015,13 @@ void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTET
nullptr
,
stream
);
}
#else
// Launch GEMM
cublas_gemm
(
A_tensor
,
B_tensor
,
D_tensor
,
epilogue_bias_tensor
,
epilogue_aux_tensor
,
transa
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
transb
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
with_grad_epilogue
,
workspace_ptr
,
workspace_size
,
alpha
,
beta
,
config_
.
use_split_accumulator
,
config_
.
sm_count
,
0
,
0
,
false
,
nullptr
,
stream
);
#endif
}
void
nvte_cublas_gemm_scaled
(
const
NVTETensor
A
,
const
NVTETensor
B
,
NVTETensor
D
,
...
...
@@ -1042,7 +1044,7 @@ void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor
if
(
is_nvfp_scaling
(
inputA
->
scaling_mode
)
||
is_nvfp_scaling
(
inputB
->
scaling_mode
))
{
NVTE_ERROR
(
"nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead."
);
}
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK
(
alpha
==
1.0
f
,
"alpha must be 1.0 for hip"
);
NVTE_CHECK
(
beta
==
1.0
f
||
beta
==
0.0
f
,
"beta must be 1.0 or 0.0 for hip"
);
bool
accumulate
=
false
;
...
...
@@ -1111,6 +1113,7 @@ void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
(
transa
)
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
(
transb
)
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
&
alpha
,
&
beta
,
use_split_accumulator
,
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
);
#endif
}
void
nvte_cublas_atomic_gemm
(
const
NVTETensor
A
,
const
NVTETensor
B
,
NVTETensor
D
,
...
...
@@ -1156,8 +1159,6 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
const
Tensor
*
inputCounter
=
convertNVTETensor
(
counter
);
Tensor
*
wspace
=
convertNVTETensor
(
workspace
);
const
void
*
alpha_ptr
=
GetScalarOne
();
const
void
*
beta_ptr
=
accumulate
?
GetScalarOne
()
:
GetScalarZero
();
NVTE_CHECK
(
is_delayed_tensor_scaling
(
inputA
->
scaling_mode
)
&&
is_delayed_tensor_scaling
(
inputB
->
scaling_mode
),
...
...
@@ -1211,6 +1212,8 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
stream
);
}
#else
const
void
*
alpha_ptr
=
GetScalarOne
();
const
void
*
beta_ptr
=
accumulate
?
GetScalarOne
()
:
GetScalarZero
();
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
(
transa
)
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
(
transb
)
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
alpha_ptr
,
beta_ptr
,
use_split_accumulator
,
math_sm_count
,
m_split
,
n_split
,
...
...
transformer_engine/common/hadamard_transform/hadamard_transform.cu
View file @
962b976d
...
...
@@ -5,9 +5,7 @@
************************************************************************/
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
...
...
transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu
View file @
962b976d
...
...
@@ -5,9 +5,7 @@
************************************************************************/
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
...
...
transformer_engine/common/include/transformer_engine/transformer_engine.h
View file @
962b976d
...
...
@@ -14,8 +14,6 @@
#include <cuda_runtime_api.h>
#include <stddef.h>
#define TE_FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
#ifdef __cplusplus
extern
"C"
{
#endif
...
...
@@ -33,13 +31,9 @@ enum NVTEDType {
kNVTEBFloat16
=
6
,
/*!< 16-bit bfloat (E8M7) */
kNVTEFloat8E4M3
=
7
,
/*!< 8-bit float (E4M3) */
kNVTEFloat8E5M2
=
8
,
/*!< 8-bit float (E5M2) */
kNVTEFloat8E8M0
=
9
,
/*!< 8-bit float (E8M0) */
#if TE_FP4_TYPE_SUPPORTED
kNVTEFloat4E2M1
=
10
,
/*!< 4-bit float (E2M1) */
kNVTEInt8
=
11
,
/*!< 8-bit integer */
#else
kNVTEInt8
=
10
,
/*!< 8-bit integer */
#endif
kNVTEInt8
=
9
,
/*!< 8-bit integer */
kNVTEFloat8E8M0
=
10
,
/*!< 8-bit float (E8M0) */
kNVTEFloat4E2M1
=
11
,
/*!< 4-bit float (E2M1) */
kNVTENumTypes
/*!< Number of supported types */
};
...
...
@@ -423,13 +417,9 @@ enum class DType {
kBFloat16
=
6
,
kFloat8E4M3
=
7
,
kFloat8E5M2
=
8
,
kFloat8E8M0
=
9
,
#if TE_FP4_TYPE_SUPPORTED
kFloat4E2M1
=
10
,
kInt8
=
11
,
#else
kInt8
=
10
,
#endif
kInt8
=
9
,
kFloat8E8M0
=
10
,
kFloat4E2M1
=
11
,
kNumTypes
};
...
...
@@ -457,11 +447,7 @@ inline bool is_fp8_dtype(const DType t) {
* \param[in] DType TE Datatype of interest
*/
inline
bool
is_fp4_dtype
(
const
DType
t
)
{
#if TE_FP4_TYPE_SUPPORTED
return
t
==
DType
::
kFloat4E2M1
;
#else
return
false
;
#endif
}
/*! \brief Check if TE datatype is high precision (FP32, FP16, BF16)
...
...
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu
View file @
962b976d
...
...
@@ -7,13 +7,17 @@
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#else
#define CUDA_VERSION 0
#endif
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cfloat>
#ifndef __HIP_PLATFORM_AMD__
#include <cuda/barrier>
#endif
#include <utility>
#include "common/common.h"
...
...
@@ -21,7 +25,9 @@
#include "common/transpose/cast_transpose.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#ifndef __HIP_PLATFORM_AMD__
#include "curanddx.hpp"
#endif
namespace
transformer_engine
{
...
...
transformer_engine/common/util/cast_kernels.cuh
View file @
962b976d
...
...
@@ -575,7 +575,7 @@ __device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float b
}
#define DIRECT_SCALING_FACTORS_STORE 1
#ifndef __HIP_PLATFORM_AMD__
template
<
bool
COMPUTE_ACTIVATIONS
,
typename
ParamOP
,
float
(
*
OP
)(
float
,
const
ParamOP
&
),
typename
IType
,
typename
OType
,
bool
COLWISE_SCALING
,
size_t
CHUNK_DIM_Y
,
size_t
CHUNK_DIM_X
,
size_t
THREADS_PER_CHUNK
>
...
...
@@ -1065,6 +1065,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
destroy_barriers
<
STAGES
>
(
mbar
,
is_master_thread
);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
#endif
}
// namespace nvfp4_kernel
constexpr
size_t
FP8_CHUNK_DIM_Y
=
128
;
...
...
@@ -1725,6 +1726,9 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
// 2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8
template
<
bool
COMPUTE_ACTIVATIONS
,
typename
ParamOP
,
float
(
*
OP
)(
float
,
const
ParamOP
&
)>
void
nvfp4_quantize
(
const
Tensor
&
input
,
const
Tensor
*
noop
,
Tensor
*
output
,
cudaStream_t
stream
)
{
#ifdef __HIP_PLATFORM_AMD__
assert
(
false
);
#else
using
namespace
nvfp4_kernel
;
using
namespace
ptx
;
checkCuDriverContext
(
stream
);
...
...
@@ -1853,6 +1857,7 @@ void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cud
break
;
});
// NOLINT(*)
);
// NOLINT(*)
#endif
}
namespace
detail
{
...
...
transformer_engine/common/util/nvfp4_transpose.cuh
View file @
962b976d
...
...
@@ -14,6 +14,8 @@
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#else
#define CUDA_VERSION 0
#endif
#include <cuda_runtime.h>
...
...
@@ -25,7 +27,9 @@
#include "../common.h"
#include "../utils.cuh"
#ifndef __HIP_PLATFORM_AMD__
#include "curanddx.hpp"
#endif
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/transformer_engine.h"
...
...
transformer_engine/common/util/pybind_helper.h
View file @
962b976d
...
...
@@ -27,7 +27,7 @@
.value("kBFloat16", transformer_engine::DType::kBFloat16) \
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
.value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1)
;
\
.value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1)
\
.value("kInt8", transformer_engine::DType::kInt8); \
pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
...
...
transformer_engine/pytorch/csrc/extensions/attention.cpp
View file @
962b976d
...
...
@@ -99,47 +99,6 @@ std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
return
{
std
::
move
(
te_T
),
std
::
move
(
py_T
)};
}
// helper function for S and dP quantizers
std
::
pair
<
TensorWrapper
,
py
::
object
>
quantizer_helper
(
py
::
handle
quantizer
,
const
std
::
vector
<
size_t
>
&
shape
,
DType
dtype
,
bool
create_hp_tensor_for_cs
,
std
::
optional
<
at
::
Tensor
>
data
)
{
std
::
unique_ptr
<
Quantizer
>
T_quantizer
=
convert_quantizer
(
quantizer
);
TensorWrapper
te_T
;
py
::
object
py_T
;
if
(
quantizer
.
is_none
())
{
// high precision
auto
*
none_quantizer
=
dynamic_cast
<
NoneQuantizer
*>
(
T_quantizer
.
get
());
if
(
data
.
has_value
())
{
std
::
tie
(
te_T
,
py_T
)
=
none_quantizer
->
create_tensor
(
shape
,
dtype
,
data
.
value
());
}
else
{
std
::
tie
(
te_T
,
py_T
)
=
none_quantizer
->
create_tensor
(
shape
,
dtype
);
}
}
else
if
(
detail
::
IsFloat8Quantizers
(
quantizer
.
ptr
()))
{
// delayed scaling; this helps initialize scale_inv
auto
*
T_quantizer_fp8
=
dynamic_cast
<
Float8Quantizer
*>
(
T_quantizer
.
get
());
std
::
tie
(
te_T
,
py_T
)
=
T_quantizer_fp8
->
create_tensor
(
shape
,
dtype
,
data
,
std
::
nullopt
,
std
::
nullopt
);
}
else
if
(
detail
::
IsFloat8CurrentScalingQuantizers
(
quantizer
.
ptr
()))
{
// current scaling
auto
*
T_quantizer_fp8
=
dynamic_cast
<
Float8CurrentScalingQuantizer
*>
(
T_quantizer
.
get
());
if
(
create_hp_tensor_for_cs
)
{
if
(
data
.
has_value
())
{
std
::
tie
(
te_T
,
py_T
)
=
T_quantizer_fp8
->
create_unquantized_tensor_with_amax
(
shape
,
dtype
,
data
.
value
());
}
else
{
std
::
tie
(
te_T
,
py_T
)
=
T_quantizer_fp8
->
create_unquantized_tensor_with_amax
(
shape
,
dtype
);
}
}
else
{
std
::
tie
(
te_T
,
py_T
)
=
T_quantizer_fp8
->
create_tensor
(
shape
,
dtype
);
NVTE_CHECK
(
!
data
.
has_value
(),
"Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"
);
}
}
return
{
std
::
move
(
te_T
),
std
::
move
(
py_T
)};
}
// fused attention FWD with separate Q, K and V tensors
std
::
vector
<
py
::
object
>
fused_attn_fwd
(
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
float
attn_scale
,
float
p_dropout
,
...
...
transformer_engine/pytorch/csrc/quantizer.cpp
View file @
962b976d
...
...
@@ -1486,10 +1486,14 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
// We need:
// 1. Rowwise amax = amax for input
// 2. Columnwise amax = amax for RHT(input.t)
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK
(
false
,
"Not only supported for nvte_hadamard_transform_amax"
);
#else
NVTE_SCOPED_GIL_RELEASE
({
nvte_hadamard_transform_amax
(
input
.
data
(),
out
.
data
(),
0
,
this
->
rht_matrix_random_sign_mask_t
,
stream
);
});
#endif
}
else
{
// raise error since it's not supported yet
NVTE_CHECK
(
false
,
"Pre-RHT amax is not supported yet"
);
...
...
@@ -1611,12 +1615,15 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
// result of transposed RHT to the output of rowwise.
rht_output_t_cpp
.
set_rowwise_data
(
rht_output_t
.
data_ptr
(),
input
.
dtype
(),
std
::
vector
<
size_t
>
{
cols
,
rows
});
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK
(
false
,
"Not only supported for nvte_hadamard_transform"
);
#else
NVTE_SCOPED_GIL_RELEASE
({
// Perform the RHT(input.t), and write to rht_output_cpp.columnwise.
nvte_hadamard_transform
(
input
.
data
(),
rht_output_t_cpp
.
data
(),
0
,
this
->
rht_matrix_random_sign_mask_t
,
stream
);
});
#endif
// Quantize kernel will treat everything as rowwise input/output, which is
// intended.
...
...
@@ -1628,10 +1635,14 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
NVTE_CHECK
(
this
->
rht_matrix
.
defined
()
&&
this
->
rht_matrix
.
numel
()
>
0
,
"RHT matrix is not set"
);
auto
rht_matrix_nvte
=
makeTransformerEngineTensor
(
this
->
rht_matrix
);
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK
(
false
,
"Not only supported for nvte_hadamard_transform_cast_fusion_columnwise"
);
#else
NVTE_SCOPED_GIL_RELEASE
({
nvte_hadamard_transform_cast_fusion_columnwise
(
input
.
data
(),
out_transpose
.
data
(),
rht_matrix_nvte
.
data
(),
quant_config
,
stream
);
});
#endif
}
}
}
else
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment