OpenDAS / TransformerEngine · Commits · ab3e5a92

Commit ab3e5a92, authored May 09, 2025 by yuguo
Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine

Parents: a8d19fd9, 04c730c0
Changes: 174
Showing 20 changed files with 1819 additions and 324 deletions (+1819 −324)
transformer_engine/common/gemm/cublaslt_gemm.cu                              +233 −131
transformer_engine/common/include/transformer_engine/cast.h                   +61 −14
transformer_engine/common/include/transformer_engine/fused_rope.h             +30 −78
transformer_engine/common/include/transformer_engine/normalization.h          +10 −0
transformer_engine/common/include/transformer_engine/transformer_engine.h     +53 −10
transformer_engine/common/libtransformer_engine.version                        +3 −1
transformer_engine/common/normalization/common.cpp                            +31 −4
transformer_engine/common/normalization/layernorm/ln_api.cpp                  +18 −13
transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp               +12 −11
transformer_engine/common/nvshmem_api/CMakeLists.txt                          +27 −0
transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu                   +51 −0
transformer_engine/common/nvshmem_api/nvshmem_waitkernel.h                    +38 −0
transformer_engine/common/permutation/permutation.cu                           +2 −2
transformer_engine/common/recipe/__init__.py                                 +101 −0
transformer_engine/common/recipe/current_scaling.cu                            +2 −1
transformer_engine/common/recipe/recipe_common.cuh                            +27 −10
transformer_engine/common/transformer_engine.cpp                              +30 −49
transformer_engine/common/transpose/cast_transpose.h                          +36 −0
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu   +561 −0
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu   +493 −0
transformer_engine/common/gemm/cublaslt_gemm.cu

@@ -63,92 +63,170 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) {
   NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
 }

+/* Parameters for cuBLAS GEMM
+ *
+ * cuBLAS follows the BLAS convention of column-major ordering. This
+ * is different than the row-major that is typically used in
+ * Transformer Engine.
+ */
 struct GemmParam {
-  void *A;
-  void *B;
-  cublasOperation_t transA;
-  cublasOperation_t transB;
-  transformer_engine::DType Atype;
-  transformer_engine::DType Btype;
-  void *A_scale_inv;
-  void *B_scale_inv;
-  int lda;
-  int ldb;
-
-  GemmParam(cublasOperation_t transA, cublasOperation_t transB)
-      : A(nullptr),
-        B(nullptr),
-        transA(transA),
-        transB(transB),
-        Atype(transformer_engine::DType::kNumTypes),
-        Btype(transformer_engine::DType::kNumTypes),
-        A_scale_inv(nullptr),
-        B_scale_inv(nullptr),
-        lda(0),
-        ldb(0) {}
+  void *A = nullptr;
+  void *B = nullptr;
+  cublasOperation_t transA = CUBLAS_OP_N;
+  cublasOperation_t transB = CUBLAS_OP_N;
+  transformer_engine::DType Atype = transformer_engine::DType::kNumTypes;
+  transformer_engine::DType Btype = transformer_engine::DType::kNumTypes;
+  void *A_scale_inv = nullptr;
+  void *B_scale_inv = nullptr;
+  int lda = 0;  // A column strides
+  int ldb = 0;  // B column strides
 };

+/* Populate parameters for cuBLAS GEMM
+ *
+ * cuBLAS follows the BLAS convention of column-major ordering. This
+ * is different than the row-major that is typically used in
+ * Transformer Engine.
+ */
 GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA,
                                 const transformer_engine::Tensor &B, const cublasOperation_t transB,
-                                const int k, const int lda, const int ldb) {
+                                int m, int n, int k) {
   using namespace transformer_engine;
-  NVTE_CHECK(A.scaling_mode == B.scaling_mode,
-             "Inputs A and B to GEMM need to have the same scaling mode!");
+  NVTE_CHECK(A.scaling_mode == B.scaling_mode ||
+                 (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) ||
+                 (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D),
+             "Inputs A and B to GEMM need to have compatible scaling modes!");
   NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
   NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
-  GemmParam ret(transA, transB);
-  ret.lda = lda;
-  ret.ldb = ldb;
-  if (is_tensor_scaling(A.scaling_mode)) {
-    ret.A = A.data.dptr;
-    ret.Atype = A.data.dtype;
-    ret.A_scale_inv = A.scale_inv.dptr;
-    if (transA == CUBLAS_OP_T) {
-      ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype;
-      if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
-        int arch = cuda::sm_arch(cuda::current_device());
-        if (arch < 100) {
-          // Hopper and Ada - we need to use columnwise_data and change transA
-          ret.A = A.columnwise_data.dptr;
-          ret.transA = CUBLAS_OP_T;
-          ret.Atype = A.columnwise_data.dtype;
-          ret.A_scale_inv = A.columnwise_scale_inv.dptr;
-          ret.lda = k;
-        }
-      } else {
-        NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
-      }
-    }
-    ret.B = B.data.dptr;
-    ret.Btype = B.data.dtype;
-    ret.B_scale_inv = B.scale_inv.dptr;
-    if (transB == CUBLAS_OP_T) {
-      ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype;
-      if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
-        int arch = cuda::sm_arch(cuda::current_device());
-        if (arch < 100) {
-          // Hopper and Ada - we need to use columnwise_data and change transA
-          ret.B = B.columnwise_data.dptr;
-          ret.transB = CUBLAS_OP_N;
-          ret.Btype = B.columnwise_data.dtype;
-          ret.B_scale_inv = B.columnwise_scale_inv.dptr;
-          ret.ldb = k;
-        }
-      } else {
-        NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
-      }
-    }
-  } else {
-    // If not tensor scaling (which includes also high precision types), we need to
-    // use the proper version of data
-    // We leave the transA/B values as is, since Blackwell supports transposes
-    ret.A = transA ? A.data.dptr : A.columnwise_data.dptr;
-    ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype;
-    ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
-    ret.B = transB ? B.columnwise_data.dptr : B.data.dptr;
-    ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype;
-    ret.B_scale_inv = transB ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
-  }
+  GemmParam ret;
+  // Transpose mode with column-major ordering
+  bool is_A_transposed = transA == CUBLAS_OP_T;
+  bool is_B_transposed = transB == CUBLAS_OP_T;
+
+  // Configure A matrix
+  if (is_tensor_scaling(A.scaling_mode)) {
+    // Unscaled or FP8 tensor scaling
+    ret.A = A.data.dptr;
+    ret.transA = transA;
+    ret.Atype = A.data.dtype;
+    ret.A_scale_inv = A.scale_inv.dptr;
+    ret.lda = is_A_transposed ? k : m;
+    if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
+      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
+      if (is_fp8_dtype(ret.Atype)) {
+        NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!");
+        ret.A = A.columnwise_data.dptr;
+        ret.transA = CUBLAS_OP_T;
+        ret.Atype = A.columnwise_data.dtype;
+        ret.A_scale_inv = A.columnwise_scale_inv.dptr;
+        ret.lda = k;
+      }
+    }
+  } else if (is_mxfp_scaling(A.scaling_mode)) {
+    // MXFP8
+    // Note: Row-wise and column-wise data are scaled along different
+    // dimensions (with matrix interpreted in row-major order).
+    if (is_A_transposed) {
+      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
+    } else {
+      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
+    }
+    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
+    ret.transA = transA;
+    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
+    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
+    ret.lda = is_A_transposed ? k : m;
+  } else if (A.scaling_mode == NVTE_BLOCK_SCALING_1D || A.scaling_mode == NVTE_BLOCK_SCALING_2D) {
+    // FP8 block scaling
+    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
+    if (is_A_transposed) {
+      NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage");
+    } else {
+      NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage");
+    }
+    ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr;
+    ret.transA = CUBLAS_OP_T;
+    ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype;
+    ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
+    ret.lda = k;
+    // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
+    NVTE_CHECK((ret.lda % 16) == 0,
+               "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
+    // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement.
+    // Smallest supported CType is 2 bytes in this scaling mode.
+    NVTE_CHECK((m % 8) == 0,
+               "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
+  } else {
+    NVTE_ERROR("A has unsupported scaling mode");
+  }
+
+  // Configure B matrix
+  if (is_tensor_scaling(B.scaling_mode)) {
+    // Unscaled or FP8 tensor scaling
+    ret.B = B.data.dptr;
+    ret.transB = transB;
+    ret.Btype = B.data.dtype;
+    ret.B_scale_inv = B.scale_inv.dptr;
+    ret.ldb = is_B_transposed ? n : k;
+    if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
+      // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
+      if (is_fp8_dtype(ret.Btype)) {
+        NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!");
+        ret.B = B.columnwise_data.dptr;
+        ret.transB = CUBLAS_OP_N;
+        ret.Btype = B.columnwise_data.dtype;
+        ret.B_scale_inv = B.columnwise_scale_inv.dptr;
+        ret.ldb = k;
+      }
+    }
+  } else if (is_mxfp_scaling(B.scaling_mode)) {
+    // MXFP8
+    // Note: Row-wise and column-wise data are scaled along different
+    // dimensions (with matrix interpreted in row-major order).
+    if (is_B_transposed) {
+      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
+    } else {
+      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
+    }
+    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
+    ret.transB = transB;
+    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
+    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
+    ret.ldb = is_B_transposed ? n : k;
+  } else if (B.scaling_mode == NVTE_BLOCK_SCALING_1D || B.scaling_mode == NVTE_BLOCK_SCALING_2D) {
+    // FP8 block scaling
+    // Note: Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
+    if (is_B_transposed) {
+      NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage");
+    } else {
+      NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage");
+    }
+    ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr;
+    ret.transB = CUBLAS_OP_N;
+    ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype;
+    ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
+    ret.ldb = k;
+    // Requirements from
+    // https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
+    NVTE_CHECK((ret.ldb % 16) == 0,
+               "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
+    if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) {
+      // Observed this requirement only present for B tensor is 1D quantized.
+      NVTE_CHECK((n % 8) == 0,
+                 "Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
+    }
+  } else {
+    NVTE_ERROR("B has unsupported scaling mode");
+  }
   return ret;
 }
@@ -167,18 +245,33 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
 #else  // Use cublasLt

 using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;

 void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
-                 const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
-                 int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad,
-                 void *workspace, size_t workspaceSize, bool accumulate,
-                 bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
-                 bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) {
+                 const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
+                 cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
+                 bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split,
+                 int n_split, bool gemm_producer, const Tensor *inputCounter,
+                 cudaStream_t stream) {
+  // Tensor dims in row-major order
+  const int A0 = inputA->flat_first_dim();
+  const int A1 = inputA->flat_last_dim();
+  const int B0 = inputB->flat_first_dim();
+  const int B1 = inputB->flat_last_dim();
+
+  // GEMM dims in column-major order
+  const int m = transa == CUBLAS_OP_T ? A0 : A1;
+  const int n = transb == CUBLAS_OP_T ? B1 : B0;
+  const int k = transa == CUBLAS_OP_T ? A1 : A0;
+  NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k,
+             "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x",
+             B1, ")");
+  const int ldd = m;
+
   // Return immediately if GEMM is trivial
   if (m <= 0 || n <= 0) {
     return;
   }
   NVTE_CHECK(k > 0);

-  const GemmParam &param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, k, lda, ldb);
+  const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);

   void *C = outputD->data.dptr;
   void *D = outputD->data.dptr;
   void *D_scale = outputD->scale.dptr;
@@ -240,6 +333,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                                                param.transA == CUBLAS_OP_N ? k : m, param.lda));
   NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type,
                                                param.transB == CUBLAS_OP_N ? k : n,
                                                param.transB == CUBLAS_OP_N ? n : k, param.ldb));
   NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
   NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
@@ -265,7 +359,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   // Scaling factors.
 #if CUDA_VERSION >= 12080
-  cublasLtMatmulMatrixScale_t scaling_mode;
+  cublasLtMatmulMatrixScale_t scaling_mode_a;
+  cublasLtMatmulMatrixScale_t scaling_mode_b;
 #endif
   if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) {
     void *A_scale_inverse = param.A_scale_inv;
@@ -277,8 +372,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                                                      CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                      &B_scale_inverse, sizeof(B_scale_inverse)));
 #if CUDA_VERSION >= 12080
-    scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
-  } else if ((is_block_scaling(inputA->scaling_mode) && is_block_scaling(inputB->scaling_mode))) {
+    scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
+    scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
+  } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) {
     fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
     fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
@@ -287,7 +383,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                      CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                      &B_scale_inverse, sizeof(B_scale_inverse)));
-    scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
+    scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
+    scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
     // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
     // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
     if (cublasLtGetVersion() <= 120803) {
@@ -296,7 +393,32 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
           operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
           sizeof(dummy_a_vec_stride)));
     }
 #endif  // CUDA_VERSION >= 12080
+  } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ||
+              inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) &&
+             (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
+              inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
+#if CUDA_VERSION >= 12090
+    float *A_scale_inverse = reinterpret_cast<float *>(param.A_scale_inv);
+    float *B_scale_inverse = reinterpret_cast<float *>(param.B_scale_inv);
+    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
+        operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &A_scale_inverse,
+        sizeof(A_scale_inverse)));
+    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
+        operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse,
+        sizeof(B_scale_inverse)));
+    NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
+                  inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)),
+               "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported, but got 2D by 2D");
+    scaling_mode_a = inputA->scaling_mode == NVTE_BLOCK_SCALING_1D
+                         ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
+                         : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
+    scaling_mode_b = inputB->scaling_mode == NVTE_BLOCK_SCALING_1D
+                         ? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
+                         : CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
+#else
+    NVTE_ERROR("FP8 block scaling requires CUDA 12.9+");
+#endif  // CUDA_VERSION >= 12090
   } else {
     NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " +
                to_string(inputB->scaling_mode) + ".");
@@ -304,9 +426,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
 #if CUDA_VERSION >= 12080
-  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
-      operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
-  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
-      operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
+  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
+      operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a)));
+  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
+      operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b)));
 #endif
   if (is_fp8_dtype(outputD->data.dtype)) {
     // Accumulation mode not supported for FP8 output
@@ -316,8 +438,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
         operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
 #if CUDA_VERSION >= 12080
-    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
-        operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
+    // NOTE: In all current cases where FP8 output is supported, the input is
+    // scaled identically to the output.
+    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
+        operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a)));
 #endif
     // For FP8 output, cuBLAS requires C_type to match bias_type and
     // be FP16/BF16
@@ -375,6 +500,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
         operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
   }

+  if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D) ||
+      (inputA->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
+    NVTE_CHECK((epilogue == CUBLASLT_EPILOGUE_DEFAULT || epilogue == CUBLASLT_EPILOGUE_BIAS ||
+                epilogue == CUBLASLT_EPILOGUE_DGELU),
+               "Epilogue requested outside of the available and tested cuBLAS functionality for "
+               "float8 block scaled GEMM");
+  }
+
   NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                    &epilogue, sizeof(epilogue)));
@@ -422,7 +555,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
              "Unable to find suitable cuBLAS GEMM algorithm");
   NVTE_CHECK_CUBLAS(status);

   if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");

   // D = alpha * (A * B) + beta * C
@@ -494,6 +626,7 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
   Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
   Tensor *wspace = reinterpret_cast<Tensor *>(workspace);

+#ifdef __HIP_PLATFORM_AMD__
   const size_t A0 = inputA->flat_first_dim();
   const size_t A1 = inputA->flat_last_dim();
   const size_t B0 = inputB->flat_first_dim();
@@ -519,32 +652,13 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
     NVTE_ERROR("TT layout not allowed.");
   }

 #ifdef __HIP_PLATFORM_AMD__
   const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
   const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                        is_fp8_dtype(inputB->data.dtype);
   if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || (use_fp8) ||
       (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') ||
       (nvte_use_hipblaslt) || (nvte_use_rocblas)) {
     cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                 transa, transb, grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                 use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream,
                 nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
   } else {
     hipblas_gemm(inputA,
                  inputB,
@@ -565,8 +679,11 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
                  nullptr,
                  stream);
   }
-#endif  //__HIP_PLATFORM_AMD__
+#else
+  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu,
+              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad,
+              wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
+              math_sm_count, 0, 0, false, nullptr, stream);
+#endif  //__HIP_PLATFORM_AMD__
 }

 void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
@@ -596,7 +713,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
   NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
                  is_delayed_tensor_scaling(inputB->scaling_mode),
              "Atomic GEMM only supports delayed scaling.");

+#ifdef __HIP_PLATFORM_AMD__
   const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
   const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
   const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
@@ -617,32 +734,13 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
     NVTE_ERROR("TT layout not allowed.");
   }

 #ifdef __HIP_PLATFORM_AMD__
   const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
   const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
                        is_fp8_dtype(inputB->data.dtype);
   if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || (use_fp8) ||
       (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') ||
       (nvte_use_hipblaslt) || (nvte_use_rocblas)) {
     cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                 transa, transb, grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                 use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
                 inputCounter, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
   } else {
     hipblas_gemm(inputA,
                  inputB,
@@ -663,8 +761,12 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
                  inputCounter,
                  stream);
   }
-#endif  //__HIP_PLATFORM_AMD__
+#else
+  cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu,
+              (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad,
+              wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator,
+              math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);
+#endif  //__HIP_PLATFORM_AMD__
 }
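The column-major bookkeeping that GemmParam and CanonicalizeGemmInput perform can be checked with a small, self-contained sketch. This is an illustration only, not code from this commit: a row-major M x K buffer read with column-major indexing and leading dimension K is exactly the transpose, so a row-major C = A * B is obtained from a column-major GEMM by swapping the operands and the m/n dimensions.

// Illustration only: why a row-major GEMM maps onto a column-major BLAS by swapping operands.
#include <cassert>
#include <cstddef>
#include <vector>

// Column-major GEMM, C(m x n) = A(m x k) * B(k x n), as a stand-in for cublasLtMatmul.
static void gemm_colmajor(int m, int n, int k, const float *A, int lda, const float *B, int ldb,
                          float *C, int ldc) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += A[i + p * lda] * B[p + j * ldb];
      C[i + j * ldc] = acc;
    }
}

int main() {
  const int M = 2, K = 3, N = 4;
  std::vector<float> A(M * K), B(K * N), C(M * N), ref(M * N);
  for (size_t i = 0; i < A.size(); ++i) A[i] = 0.5f * static_cast<float>(i);
  for (size_t i = 0; i < B.size(); ++i) B[i] = 0.25f * static_cast<float>(i) - 1.f;

  // Reference: row-major C = A * B.
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int p = 0; p < K; ++p) acc += A[i * K + p] * B[p * N + j];
      ref[i * N + j] = acc;
    }

  // Column-major call with swapped operands: C^T(N x M) = B^T(N x K) * A^T(K x M).
  gemm_colmajor(N, M, K, B.data(), N, A.data(), K, C.data(), N);

  for (size_t i = 0; i < C.size(); ++i) assert(C[i] == ref[i]);
  return 0;
}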
transformer_engine/common/include/transformer_engine/cast.h

@@ -17,22 +17,31 @@
 extern "C" {
 #endif

-/* Cast the tensor to FP8 (or microscaling FP8 if the compute capability of the device is 10.0 or newer)
- * The implementation is per the microscaling format MXFP8 defined by the OCP specification:
+/* Quantize the tensor
+ *
+ * The type of quantized tensor in the output depends on the scaling mode of the output
+ * tensor.
+ *
+ * Supported formats are:
+ *
+ * 1) MXFP8 scaling (for compute capability 10.0 or newer)
+ *
+ *    The MXFP8 implementation is per the microscaling format MXFP8 defined by the OCP specification:
  *    https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
  *
- * Supported modes of scaling (live scaling):
- *
- * 1) Rowwise scaling (along the dim=0) computes one set of the output data, which includes:
+ *    Supported modes of MXFP8 scaling (live scaling) for scaling mode NVTE_MXFP8_1D_SCALING
+ *    a) Rowwise scaling (along the dim=0) computes one set of the output data, which includes:
  *       - the scaled output tensor
  *       - the corresponding scaling factors
  *       The scaling factors are computed for blocks of the shape [1,32]
  *       (i.e., each scaling factor spans 32 contiguous elements along rows).
  *
- * 2) Columwise scaling (along the dim=1) computes one set of the output data.
+ *    b) Columwise scaling (along the dim=1) computes one set of the output data.
  *       The scaling factors are computed for blocks of the shape [32,1]
  *       (i.e., each scaling factor spans 32 contiguous elements along columns).
  *
- * 3) Both rowwise AND columnwise scaling (along the dim=0 and the dim=1)
+ *    c) Both rowwise AND columnwise scaling (along the dim=0 and the dim=1)
  *       computes two sets of the output data: both 1) and 2).
  *
  * The shape of the MX block must be specified in the 'output' argument,

@@ -40,31 +49,69 @@ extern "C" {
  *
  * To cast the input tensor to the MXFP8, the scaling_mode.delayed_scaling parameter
  * of the output tensor should be set to 0.
+ *
+ * 2) NVTE_DELAYED_TENSOR_SCALING that quantize the entire tensor
+ *    using a single scaling factor. The absolute maximum value of the tensor should
+ *    be precalculated either online (current scaling) or based on a tensor history
+ *    (delayed scaling). The calls to nvte_quantize scale based on that data value.
+ *    Note the NVTE_DELAYED_TENSOR_SCALING NVTEScalingMode is reused for online
+ *    per tensor scaling.
+ *
+ * 3) FP8 block scaling formats NVTE_BLOCK_SCALING_1D and NVTE_BLOCK_SCALING_2D
+ *    for compute capability of at least 9.0. These modes quantize the tensor by blocks
+ *    of size 1x128 (with columnwise mode of 128x1) and 128x128 respectively.
+ *
+ *    The supported modes are:
+ *    a) Rowwise scaling yields output data:
+ *       - the scaled output tensor in fp8 coefficients with identical shape to the
+ *         input tensor.
+ *       - Scale factors which are computed for either 1D 1x128 or 2D 128x128 blocks.
+ *    b) Columnwise scaling yields output data:
+ *       - the scaled output tensor in fp8 coefficients with a shape equivalent to
+ *         the transpose of the input tensor.
+ *       - Scale factors which are calculated for either 1D 128x1 or 2D 128x128 blocks
+ *         of the input tensor.
+ *    c) Both: In which both tensors and both scales are calculated.
+ *
+ *    This quantization mode includes both the calculation of the scaling factors
+ *    per-tile and quantization of the row and/or columnwise tiles. No precalculated
+ *    absolute max is required. The scaling factors are also rounded to powers of 2.
  */

-/*! \brief Casts input tensor to FP8/MXFP8.
- *  If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
- *  the block quantization (MXFP8) of the specified shape of the block will be used.
+/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8.
+ *  The type of quantized tensor in the output depends on the scaling mode of the output
+ *  tensor. See file level comments.
  *
  *  \param[in]     input     Input tensor to be cast.
- *  \param[in,out] output    Output FP8/MXFP8 tensor.
+ *  \param[in,out] output    Output FP8/MXFP8/BlockwiseFP8 tensor.
  *  \param[in]     stream    CUDA stream used for the operation.
  */
 void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);

-/*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel
+/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel
  *         based on the value of the 'noop' tensor.
- *  If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
- *  the block quantization (MXFP8) of the specified shape of the block will be used.
+ *  The type of quantized tensor in the output depends on the scaling mode of the output
+ *  tensor. See file level comments.
  *
  *  \param[in]     input     Input tensor to be cast.
- *  \param[in,out] output    Output FP8/MXFP8 tensor.
+ *  \param[in,out] output    Output quantized tensor.
  *  \param[out]    noop      Noop tensor.
  *  \param[in]     stream    CUDA stream used for the operation.
  */
 void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
                         cudaStream_t stream);

+/*! \brief Casts input tensor to quantized output tensor, with advanced quantization options.
+ *
+ *  \param[in]     input         Input tensor to be cast.
+ *  \param[in,out] output        Output quantized tensor.
+ *  \param[in]     quant_config  Quantization configuration.
+ *  \param[in]     stream        CUDA stream used for the operation.
+ */
+void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
+                      const NVTEQuantizationConfig quant_config, cudaStream_t stream);
+
 /*! \brief Casts input tensor to MXFP8. Additionally, reduces the input along columns.
  *  If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
  *  the block quantization (MXFP8) of the specified shape of the block will be used.
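A minimal usage sketch of the quantization entry point documented above (not from this commit). The device allocations and the assumption that TensorWrapper::data() exposes the underlying NVTETensor handle are mine; with the default NVTE_DELAYED_TENSOR_SCALING mode, a single amax/scale/scale_inv triple drives the cast.

// Hedged sketch: per-tensor FP8 quantization through nvte_quantize.
#include <cuda_runtime.h>
#include <vector>
#include <transformer_engine/cast.h>
#include <transformer_engine/transformer_engine.h>

using transformer_engine::DType;
using transformer_engine::TensorWrapper;

void quantize_fp32_to_fp8(const float *src, void *dst_fp8,              // device pointers
                          float *amax, float *scale, float *scale_inv,  // device scalars
                          size_t rows, size_t cols, cudaStream_t stream) {
  // High-precision input: no scaling metadata needed.
  TensorWrapper input(const_cast<float *>(src), std::vector<size_t>{rows, cols}, DType::kFloat32);
  // FP8 output with the default NVTE_DELAYED_TENSOR_SCALING mode: one scale, one
  // scale_inv and one amax for the whole tensor (also reused for current scaling).
  TensorWrapper output(dst_fp8, std::vector<size_t>{rows, cols}, DType::kFloat8E4M3,
                       amax, scale, scale_inv);
  nvte_quantize(input.data(), output.data(), stream);  // assumes data() returns NVTETensor
}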
transformer_engine/common/include/transformer_engine/fused_rope.h

@@ -7,6 +7,7 @@
 #ifndef TRANSFORMER_ENGINE_FUSED_ROPE_H_
 #define TRANSFORMER_ENGINE_FUSED_ROPE_H_

+#include "fused_attn.h"
 #include "transformer_engine.h"

 #ifdef __cplusplus

@@ -16,112 +17,63 @@ extern "C" {
 /*! \brief Apply rotary positional embedding to the input tensor.
  *
  *  \param[in]     input           Input tensor for fused rope.
+ *  \param[in]     cu_seqlens      The cumulative sum of sequence lengths tensor.
+ *                                 (Required for the thd format, empty tensor for other formats)
  *  \param[in]     freqs           The freqs tensor.
  *  \param[out]    output          Output tensor.
+ *  \param[in]     qkv_format      QKV format.
+ *  \param[in]     interleaved     Whether to use interleaved rotary position embedding.
+ *  \param[in]     cp_size         Context parallel world size.
+ *  \param[in]     cp_rank         Context parallel rank.
  *  \param[in]     s               Length of the s dimension of input.
  *  \param[in]     b               Length of the b dimension of input.
  *  \param[in]     h               Length of the h dimension of input.
  *  \param[in]     d               Length of the d dimension of input.
  *  \param[in]     d2              Length of the d dimension of freqs.
- *  \param[in]     stride_s        Stride of the s dimension of input.
- *  \param[in]     stride_b        Stride of the b dimension of input.
+ *  \param[in]     stride_s_or_t   Stride of the s (sbhd/bshd)/t (thd) dimension of input.
+ *  \param[in]     stride_b        Stride of the b dimension of input. (0 for thd).
  *  \param[in]     stride_h        Stride of the h dimension of input.
  *  \param[in]     stride_d        Stride of the d dimension of input.
- *  \param[in]     o_stride_s      Stride of the s dimension of output.
- *  \param[in]     o_stride_b      Stride of the b dimension of output.
- *  \param[in]     o_stride_h      Stride of the h dimension of output.
- *  \param[in]     o_stride_d      Stride of the d dimension of output.
  *  \param[in]     stream          CUDA stream used for the operation.
  */
-void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output,
-                             const int s, const int b, const int h, const int d, const int d2,
-                             const int stride_s, const int stride_b, const int stride_h,
-                             const int stride_d, const int o_stride_s, const int o_stride_b,
-                             const int o_stride_h, const int o_stride_d, cudaStream_t stream);
+void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
+                             const NVTETensor freqs, NVTETensor output,
+                             const NVTE_QKV_Format qkv_format, const bool interleaved,
+                             const int cp_size, const int cp_rank, const int s, const int b,
+                             const int h, const int d, const int d2, const int stride_s_or_t,
+                             const int stride_b, const int stride_h, const int stride_d,
+                             cudaStream_t stream);

 /*! \brief Compute the backward of the fused rope.
  *
  *  \param[in]     output_grads    Incoming gradient tensor for backward.
+ *  \param[in]     cu_seqlens      The cumulative sum of sequence lengths tensor.
+ *                                 (Required for the thd format, empty tensor for other formats)
  *  \param[in]     freqs           The freqs tensor.
  *  \param[out]    input_grads     Input gradient tensor to calculate.
+ *  \param[in]     qkv_format      QKV format.
+ *  \param[in]     interleaved     Whether to use interleaved rotary position embedding.
+ *  \param[in]     cp_size         Context parallel world size.
+ *  \param[in]     cp_rank         Context parallel rank.
  *  \param[in]     s               Length of the s dimension of output_grads.
  *  \param[in]     b               Length of the b dimension of output_grads.
  *  \param[in]     h               Length of the h dimension of output_grads.
  *  \param[in]     d               Length of the d dimension of output_grads.
  *  \param[in]     d2              Length of the d dimension of freqs.
- *  \param[in]     stride_s        Stride of the s dimension of output_grads.
- *  \param[in]     stride_b        Stride of the b dimension of output_grads.
+ *  \param[in]     stride_s_or_t   Stride of the s (sbhd/bshd)/t (thd) dimension of output_grads.
+ *  \param[in]     stride_b        Stride of the b dimension of output_grads. (0 for thd).
  *  \param[in]     stride_h        Stride of the h dimension of output_grads.
  *  \param[in]     stride_d        Stride of the d dimension of output_grads.
- *  \param[in]     o_stride_s      Stride of the s dimension of input_grads.
- *  \param[in]     o_stride_b      Stride of the b dimension of input_grads.
- *  \param[in]     o_stride_h      Stride of the h dimension of input_grads.
- *  \param[in]     o_stride_d      Stride of the d dimension of input_grads.
  *  \param[in]     stream          CUDA stream used for the operation.
  */
-void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs,
-                              NVTETensor input_grads, const int s, const int b, const int h,
-                              const int d, const int d2, const int stride_s, const int stride_b,
-                              const int stride_h, const int stride_d, const int o_stride_s,
-                              const int o_stride_b, const int o_stride_h, const int o_stride_d,
-                              cudaStream_t stream);
-
-/*! \brief Apply rotary positional embedding to the input tensor in thd format.
- *
- *  \param[in]     input           Input tensor for fused rope.
- *  \param[in]     cu_seqlens      The cumulative sum of sequence lengths tensor.
- *  \param[in]     freqs           The freqs tensor.
- *  \param[out]    output          Output tensor.
- *  \param[in]     cp_size         Context parallel world size.
- *  \param[in]     cp_rank         Context parallel rank.
- *  \param[in]     max_s           Max sequence length.
- *  \param[in]     b               Batch size.
- *  \param[in]     h               Length of the h dimension of input.
- *  \param[in]     d               Length of the d dimension of input.
- *  \param[in]     d2              Length of the d dimension of freqs.
- *  \param[in]     stride_t        Stride of the t dimension of input.
- *  \param[in]     stride_h        Stride of the h dimension of input.
- *  \param[in]     stride_d        Stride of the d dimension of input.
- *  \param[in]     o_stride_t      Stride of the t dimension of output.
- *  \param[in]     o_stride_h      Stride of the h dimension of output.
- *  \param[in]     o_stride_d      Stride of the d dimension of output.
- *  \param[in]     stream          CUDA stream used for the operation.
- */
-void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
-                                 const NVTETensor freqs, NVTETensor output, const int cp_size,
-                                 const int cp_rank, const int max_s, const int b, const int h,
-                                 const int d, const int d2, const int stride_t, const int stride_h,
-                                 const int stride_d, const int o_stride_t, const int o_stride_h,
-                                 const int o_stride_d, cudaStream_t stream);
-
-/*! \brief Compute the backward of the fused rope in thd format.
- *
- *  \param[in]     output_grads    Incoming gradient tensor for backward.
- *  \param[in]     cu_seqlens      The cumulative sum of sequence lengths tensor.
- *  \param[in]     freqs           The freqs tensor.
- *  \param[out]    input_grads     Input gradient to calculate.
- *  \param[in]     cp_size         Context parallel world size.
- *  \param[in]     cp_rank         Context parallel rank.
- *  \param[in]     max_s           Max sequence length.
- *  \param[in]     b               Batch size.
- *  \param[in]     h               Length of the h dimension of output_grads.
- *  \param[in]     d               Length of the d dimension of output_grads.
- *  \param[in]     d2              Length of the d dimension of freqs.
- *  \param[in]     stride_t        Stride of the t dimension of output_grads.
- *  \param[in]     stride_h        Stride of the h dimension of output_grads.
- *  \param[in]     stride_d        Stride of the d dimension of output_grads.
- *  \param[in]     o_stride_t      Stride of the t dimension of input_grads.
- *  \param[in]     o_stride_h      Stride of the h dimension of input_grads.
- *  \param[in]     o_stride_d      Stride of the d dimension of input_grads.
- *  \param[in]     stream          CUDA stream used for the operation.
- */
-void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
-                                  const NVTETensor freqs, NVTETensor input_grads,
-                                  const int cp_size, const int cp_rank, const int max_s,
-                                  const int b, const int h, const int d, const int d2,
-                                  const int stride_t, const int stride_h, const int stride_d,
-                                  const int o_stride_t, const int o_stride_h, const int o_stride_d,
-                                  cudaStream_t stream);
+void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
+                              const NVTETensor freqs, NVTETensor input_grads,
+                              const NVTE_QKV_Format qkv_format, const bool interleaved,
+                              const int cp_size, const int cp_rank, const int s, const int b,
+                              const int h, const int d, const int d2, const int stride_s_or_t,
+                              const int stride_b, const int stride_h, const int stride_d,
+                              cudaStream_t stream);

 #ifdef __cplusplus
 }  // extern "C"
 #endif
transformer_engine/common/include/transformer_engine/normalization.h

@@ -149,6 +149,16 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor
 void nvte_enable_cudnn_norm_fwd(bool enable);
 void nvte_enable_cudnn_norm_bwd(bool enable);

+/*! \brief Control whether norm computes `gamma += 1.0` for zero-centered gamma
+ *         in weight dtype. If set to false, it will compute in compute dtype.
+ *
+ *  Currently this only applies to the CuDNN backend. If CuDNN is not used,
+ *  this setting has no effect.
+ *
+ *  \param[in] bool Enable if True
+ */
+void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable);
+
 enum class NVTE_Norm_Type { LayerNorm, RMSNorm };

 #ifdef __cplusplus
transformer_engine/common/include/transformer_engine/transformer_engine.h
View file @
ab3e5a92
...
@@ -42,6 +42,8 @@ struct NVTEShape {
...
@@ -42,6 +42,8 @@ struct NVTEShape {
const
size_t
*
data
;
const
size_t
*
data
;
/*! \brief Number of dimensions. */
/*! \brief Number of dimensions. */
size_t
ndim
;
size_t
ndim
;
/*! \brief Copy of data. Num dims limited to permit fixed struct size.*/
size_t
owned_data
[
14
];
};
};
/*! \struct NVTEBasicTensor
/*! \struct NVTEBasicTensor
...
@@ -80,8 +82,13 @@ enum NVTEScalingMode {
...
@@ -80,8 +82,13 @@ enum NVTEScalingMode {
/*! Single scale per block of 32 elements consecutive in either
/*! Single scale per block of 32 elements consecutive in either
rowwise or columnwise direction */
rowwise or columnwise direction */
NVTE_MXFP8_1D_SCALING
=
1
,
NVTE_MXFP8_1D_SCALING
=
1
,
NVTE_INVALID_SCALING
=
2
,
/*! Tensor is split into NxN quantization tiles or 1xN quantization tiles,
NVTE_NO_SCALING
=
3
which each yield a scale. The block_scaling_dim property of the quantizer
selects the granularity.
*/
NVTE_BLOCK_SCALING_1D
=
2
,
NVTE_BLOCK_SCALING_2D
=
3
,
NVTE_INVALID_SCALING
=
100
};
};
/*! \brief TE Tensor type
/*! \brief TE Tensor type
...
@@ -129,6 +136,15 @@ void *nvte_tensor_data(const NVTETensor tensor);
...
@@ -129,6 +136,15 @@ void *nvte_tensor_data(const NVTETensor tensor);
*/
*/
void
*
nvte_tensor_columnwise_data
(
const
NVTETensor
tensor
);
void
*
nvte_tensor_columnwise_data
(
const
NVTETensor
tensor
);
/*! \brief Construct a shape from an array of dimension sizes.
*
* \param[data] Pointer to start of shape array.
* \param[data] Number of dimensions (must be <= 14)
*
* \return A shape. The shape will own its own copy of the data.
*/
NVTEShape
nvte_make_shape
(
const
size_t
*
data
,
size_t
ndim
);
/*! \brief Get a tensor's data shape.
/*! \brief Get a tensor's data shape.
*
*
* \param[in] tensor Tensor.
* \param[in] tensor Tensor.
...
@@ -281,6 +297,12 @@ enum NVTEQuantizationConfigAttribute {
...
@@ -281,6 +297,12 @@ enum NVTEQuantizationConfigAttribute {
kNVTEQuantizationConfigForcePow2Scales
=
0
,
kNVTEQuantizationConfigForcePow2Scales
=
0
,
/*! Small value to add to amax for numerical stability */
/*! Small value to add to amax for numerical stability */
kNVTEQuantizationConfigAmaxEpsilon
=
1
,
kNVTEQuantizationConfigAmaxEpsilon
=
1
,
/*! Noop tensor (containing a scalar).
If the scalar element value = 1, quantization kernel will early exit.
This is a tensor because the flag must be on GPU in order to enable
conditional early even when captured in a static CUDA graph.
*/
kNVTEQuantizationConfigNoopTensor
=
2
,
kNVTEQuantizationConfigNumAttributes
kNVTEQuantizationConfigNumAttributes
};
};
...
@@ -406,8 +428,9 @@ class TensorWrapper {
...
@@ -406,8 +428,9 @@ class TensorWrapper {
float
*
amax_dptr
=
nullptr
,
float
*
scale_dptr
=
nullptr
,
float
*
amax_dptr
=
nullptr
,
float
*
scale_dptr
=
nullptr
,
float
*
scale_inv_dptr
=
nullptr
,
const
std
::
vector
<
size_t
>
&
scale_inv_shape
=
{
1
},
float
*
scale_inv_dptr
=
nullptr
,
const
std
::
vector
<
size_t
>
&
scale_inv_shape
=
{
1
},
const
NVTEScalingMode
scaling_mode
=
NVTE_DELAYED_TENSOR_SCALING
)
const
NVTEScalingMode
scaling_mode
=
NVTE_DELAYED_TENSOR_SCALING
)
:
TensorWrapper
(
dptr
,
NVTEShape
{
shape
.
data
(),
shape
.
size
()},
dtype
,
amax_dptr
,
scale_dptr
,
:
TensorWrapper
(
dptr
,
nvte_make_shape
(
shape
.
data
(),
shape
.
size
()),
dtype
,
amax_dptr
,
scale_inv_dptr
,
NVTEShape
{
scale_inv_shape
.
data
(),
scale_inv_shape
.
size
()},
scale_dptr
,
scale_inv_dptr
,
nvte_make_shape
(
scale_inv_shape
.
data
(),
scale_inv_shape
.
size
()),
scaling_mode
)
{}
scaling_mode
)
{}
 /*! \brief Constructs new empty TensorWrapper.
...
@@ -523,7 +546,9 @@ class TensorWrapper {
   * \return Shape of this TensorWrapper.
   */
  const NVTEShape shape() const noexcept {
-    if (tensor_ == nullptr) return NVTEShape{nullptr, 0};
+    if (tensor_ == nullptr) {
+      return nvte_make_shape(nullptr, 0);
+    }
    return nvte_tensor_shape(tensor_);
  }
...
@@ -532,7 +557,9 @@ class TensorWrapper {
   * \return Shape of this TensorWrapper.
   */
  const NVTEShape columnwise_shape() const noexcept {
-    if (tensor_ == nullptr) return NVTEShape{nullptr, 0};
+    if (tensor_ == nullptr) {
+      return nvte_make_shape(nullptr, 0);
+    }
    return nvte_tensor_columnwise_shape(tensor_);
  }
...
@@ -645,7 +672,9 @@ class TensorWrapper {
   * \return scale_inv_shape of this TensorWrapper.
   */
  const NVTEShape scale_inv_shape() const noexcept {
-    if (tensor_ == nullptr) return NVTEShape{nullptr, 0};
+    if (tensor_ == nullptr) {
+      return nvte_make_shape(nullptr, 0);
+    }
    return nvte_tensor_scale_inv_shape(tensor_);
  }
...
@@ -661,12 +690,20 @@ class TensorWrapper {
  void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); }

  static constexpr size_t defaultData = 1;
-  static constexpr NVTEShape defaultShape = {&defaultData, 1};
+  static constexpr NVTEShape defaultShape = {
+      &defaultData, 1, {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};

 private:
-  NVTEShape convertShape(const NVTEShape &s) { return s; }
+  NVTEShape convertShape(const NVTEShape &s) {
+    NVTEShape ret = s;
+    // Move the ownership rather than pointing to the parent shape.
+    ret.data = ret.owned_data;
+    return ret;
+  }

-  NVTEShape convertShape(const std::vector<size_t> &s) { return {s.data(), s.size()}; }
+  NVTEShape convertShape(const std::vector<size_t> &s) {
+    return nvte_make_shape(s.data(), s.size());
+  }

  /*! \brief Wrapped NVTETensor. */
  NVTETensor tensor_ = nullptr;
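Because NVTEShape now carries an owned_data copy of the dimensions, shape queries can hand back self-contained values. A short hedged sketch of a null-safe shape query under that assumption (the wrapper and output format are illustrative):

#include <transformer_engine/transformer_engine.h>

#include <cstdio>

// Minimal sketch: shape queries are null-safe and return owning shapes.
// An empty TensorWrapper reports ndim == 0 instead of handing back a raw
// {nullptr, 0} aggregate.
void print_shape(const transformer_engine::TensorWrapper &t) {
  const NVTEShape s = t.shape();  // valid even if the wrapper is empty
  std::printf("ndim = %zu\n", s.ndim);
  for (size_t i = 0; i < s.ndim; ++i) {
    std::printf("  dim[%zu] = %zu\n", i, s.data[i]);
  }
}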
...
@@ -719,6 +756,12 @@ class QuantizationConfigWrapper {
                                            &amax_epsilon, sizeof(float));
  }

+  /*! \brief Set noop tensor pointer */
+  void set_noop_tensor(NVTETensor noop_tensor) {
+    nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNoopTensor,
+                                           &noop_tensor, sizeof(NVTETensor));
+  }

 private:
  /*! \brief Wrapped NVTEQuantizationConfig. */
  NVTEQuantizationConfig config_ = nullptr;
...
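A hedged usage sketch for the new noop-tensor attribute follows; the NVTETensor flag and its setup are assumed for illustration, and the commented call shows the equivalent raw C-API attribute write added in this commit.

#include <transformer_engine/transformer_engine.h>

// Minimal sketch: attach a GPU-resident "noop" flag to a quantization config.
// `noop_flag` is assumed to be an NVTETensor holding a single float on the
// device; when that value is 1, quantization kernels reading this config can
// exit early, which keeps the decision on-GPU and CUDA-graph friendly.
void configure_noop(transformer_engine::QuantizationConfigWrapper &config,
                    NVTETensor noop_flag) {
  config.set_noop_tensor(noop_flag);

  // Equivalent call through the C API added in this commit (config_handle is
  // the underlying NVTEQuantizationConfig, named hypothetically here):
  // nvte_set_quantization_config_attribute(config_handle,
  //                                        kNVTEQuantizationConfigNoopTensor,
  //                                        &noop_flag, sizeof(NVTETensor));
}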
transformer_engine/common/libtransformer_engine.version  View file @ ab3e5a92
...
@@ -16,7 +16,9 @@
     transformer_engine::is_fp8_dtype*;
     *transformer_engine::CommOverlapBase*;
     *transformer_engine::CommOverlapP2PBase*;
-    *transformer_engine::CommOverlapCore*
+    *transformer_engine::CommOverlapCore*;
+    *nvshmem_wait_on_stream*;
+    *nvshmemi_init_thread*
   };
   local: *;
 };
transformer_engine/common/normalization/common.cpp  View file @ ab3e5a92
...
@@ -39,6 +39,8 @@ Compute always in FP32
 namespace transformer_engine {
 namespace normalization {

+bool &use_zero_centered_gamma_in_weight_dtype();

 #ifndef __HIP_PLATFORM_AMD__
 cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) {
   return training ? cudnn_frontend::NormFwdPhase_t::TRAINING
...
@@ -213,9 +215,12 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
     _ndim_scale_block = 1;
   }
-  _scalar_dptr = std::make_unique<char[]>(typeToSize(wtype));
+  const auto gamma_dtype = use_zero_centered_gamma_in_weight_dtype() ? wtype : ctype;
+  _scalar_dptr = std::make_unique<char[]>(typeToSize(gamma_dtype));
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
-      wtype, cpp_dtype, *(reinterpret_cast<cpp_dtype *>(_scalar_dptr.get())) = (cpp_dtype)1.0f;);
+      gamma_dtype, cpp_dtype, *(reinterpret_cast<cpp_dtype *>(_scalar_dptr.get())) = (cpp_dtype)1.0f;);

   _handle = cudnnExecutionPlanManager::Instance().GetHandle();
...
@@ -245,13 +250,13 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
                              .set_name("one")
                              .set_dim({1, 1, 1, 1})
                              .set_stride({1, 1, 1, 1})
-                             .set_data_type(get_cudnn_fe_dtype(wtype))
+                             .set_data_type(get_cudnn_fe_dtype(gamma_dtype))
                              .set_is_pass_by_value(true));
     auto centered_options = fe::graph::Pointwise_attributes()
                                 .set_mode(fe::PointwiseMode_t::ADD)
                                 .set_compute_data_type(get_cudnn_fe_dtype(ctype));
     _gamma = _graph.pointwise(_gamma_zero, _scalar_offset, centered_options);
-    _gamma->set_output(false).set_data_type(get_cudnn_fe_dtype(wtype));
+    _gamma->set_output(false).set_data_type(get_cudnn_fe_dtype(gamma_dtype));
   } else {
     _gamma = _gamma_zero;
   }
...
@@ -537,6 +542,18 @@ bool& _cudnn_norm_bwd_flag() {
 bool use_cudnn_norm_fwd() { return _cudnn_norm_fwd_flag(); }
 bool use_cudnn_norm_bwd() { return _cudnn_norm_bwd_flag(); }

+bool &_zero_centered_gamma_in_weight_dtype() {
+#ifdef USE_ROCM
+  static bool flag = false;
+  return flag;
+#else
+  static bool flag = transformer_engine::getenv<bool>("NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE");
+  return flag;
+#endif
+}
+
+bool &use_zero_centered_gamma_in_weight_dtype() { return _zero_centered_gamma_in_weight_dtype(); }

 }  // namespace normalization
 }  // namespace transformer_engine
...
@@ -559,3 +576,13 @@ void nvte_enable_cudnn_norm_bwd(bool enable) {
   transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable;
 #endif
 }

+void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) {
+  NVTE_API_CALL(nvte_enable_zero_centered_gamma_in_weight_dtype);
+#ifdef USE_ROCM
+  bool flag = false;
+  transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = flag;
+#else
+  transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = enable;
+#endif
+}
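A minimal sketch of flipping the new zero-centered-gamma-in-weight-dtype switch from application code; the include path and the call site are assumptions based on this diff (the env variable NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE is the other entry point, read once at first use).

#include <transformer_engine/normalization.h>  // assumed location of the new declaration

// Minimal sketch: opt in to computing the zero-centered gamma offset in the
// weight dtype rather than the compute dtype. Per this diff the call is a
// no-op on ROCm builds, where the flag stays false.
void enable_wtype_gamma() {
  nvte_enable_zero_centered_gamma_in_weight_dtype(true);
}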
transformer_engine/common/normalization/layernorm/ln_api.cpp  View file @ ab3e5a92
...
@@ -27,23 +27,28 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
                    const int multiprocessorCount, const bool zero_centered_gamma,
                    cudaStream_t stream) {
   if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
-      !is_block_scaling(z->scaling_mode)) {
+      !is_mxfp_scaling(z->scaling_mode)) {
     NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
   }

-  NVTE_CHECK(x.data.shape.size() == 2);
-  NVTE_CHECK(gamma.data.shape == beta.data.shape);
-  NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0]);
+  NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
+  NVTE_CHECK(gamma.data.shape == beta.data.shape, "Gamma and Beta must have the same shape.");
+  NVTE_CHECK(gamma.data.dtype == beta.data.dtype,
+             "Gamma and Beta must have the same dtype. Gamma dtype: " + to_string(gamma.data.dtype) +
+                 ", Beta dtype: " + to_string(beta.data.dtype));
+  NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0], "Gamma must have the same hidden size.");

-  NVTE_CHECK(epsilon >= 0.f);
+  NVTE_CHECK(epsilon >= 0.f, "Epsilon must be non-negative.");

-  NVTE_CHECK(z->data.shape == x.data.shape);
+  NVTE_CHECK(z->data.shape == x.data.shape, "Output tensor must have the same shape as x.");

-  NVTE_CHECK(mu->data.shape == std::vector<size_t>{x.data.shape[0]});
-  NVTE_CHECK(mu->data.dtype == DType::kFloat32);
+  NVTE_CHECK(mu->data.shape == std::vector<size_t>{x.data.shape[0]},
+             "Mu must be 1D tensor with shape (x.shape[0],).");
+  NVTE_CHECK(mu->data.dtype == DType::kFloat32, "Mu must be a float32 tensor.");

-  NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]});
-  NVTE_CHECK(rsigma->data.dtype == DType::kFloat32);
+  NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]},
+             "RSigma must be 1D tensor with shape (x.shape[0],).");
+  NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");

   if (!workspace->data.shape.empty()) {
     CheckInputTensor(x, "x");
...
@@ -59,11 +64,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
   bool is_aligned = true;
 #ifdef USE_ROCM
-  NVTE_CHECK(!is_block_scaling(z->scaling_mode),
+  NVTE_CHECK(!is_mxfp_scaling(z->scaling_mode),
              "Cudnn backend is need by block scaling mode for normalization! Not surpported in rocm yet.");
-  bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode);
+  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
 #else
-  bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode);
+  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
 #endif
   if (cudnn_backend) {
...
transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp  View file @ ab3e5a92
...
@@ -23,19 +23,20 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
                  Tensor *rsigma, Tensor *workspace, const int multiprocessorCount,
                  const bool zero_centered_gamma, cudaStream_t stream) {
   if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
-      !is_block_scaling(z->scaling_mode)) {
+      !is_mxfp_scaling(z->scaling_mode)) {
     NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
   }

-  NVTE_CHECK(x.data.shape.size() == 2);
-  NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]);
+  NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
+  NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1], "Gamma must have the same hidden size.");

-  NVTE_CHECK(epsilon >= 0.f);
+  NVTE_CHECK(epsilon >= 0.f, "Epsilon must be non-negative.");

-  NVTE_CHECK(z->data.shape == x.data.shape);
+  NVTE_CHECK(z->data.shape == x.data.shape, "Output tensor must have the same shape as x.");

-  NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]});
-  NVTE_CHECK(rsigma->data.dtype == DType::kFloat32);
+  NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{x.data.shape[0]},
+             "RSigma must be 1D tensor with shape (x.shape[0],).");
+  NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");

   if (!workspace->data.shape.empty()) {
     CheckInputTensor(x, "x");
...
@@ -49,11 +50,11 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
   bool is_aligned = true;
 #ifdef USE_ROCM
-  NVTE_CHECK(!is_block_scaling(z->scaling_mode),
-             "Cudnn backend is need by block scaling mode for normalization! Not surpported in rocm yet.");
-  bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode);
+  NVTE_CHECK(!is_mxfp_scaling(z->scaling_mode),
+             "Cudnn backend is need by mxfp scaling mode for normalization! Not surpported in rocm yet.");
+  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
 #else
-  bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode);
+  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
 #endif
   bool training =
...
transformer_engine/common/nvshmem_api/CMakeLists.txt  0 → 100644  View file @ ab3e5a92

##########################################################################
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
##########################################################################

cmake_minimum_required(VERSION 3.18)

project(nvshmemapi LANGUAGES CXX CUDA)

# Configure dependencies
find_package(CUDAToolkit REQUIRED)
# find_package(MPI REQUIRED)

set(NVSHMEM_HOME "$ENV{NVSHMEM_HOME}" CACHE STRING "Location of NVSHMEM installation")

add_library(nvshmemapi STATIC nvshmem_waitkernel.cu)
set(NVSHMEMAPI_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}" PARENT_SCOPE)

target_link_directories(nvshmemapi PUBLIC ${NVSHMEM_HOME}/lib)
target_link_libraries(nvshmemapi PUBLIC -static-libstdc++ nvshmem_device nvshmem_host CUDA::nvml CUDA::cublas CUDA::cuda_driver)
target_include_directories(nvshmemapi PRIVATE ${NVSHMEM_HOME}/include/)
target_include_directories(nvshmemapi PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} "${CMAKE_CURRENT_SOURCE_DIR}")

set_target_properties(nvshmemapi PROPERTIES
                      CUDA_STANDARD 17
                      POSITION_INDEPENDENT_CODE ON
                      CUDA_SEPARABLE_COMPILATION ON)
\ No newline at end of file
transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu  0 → 100644  View file @ ab3e5a92

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cuda.h>
#include <cuda_bf16.h>
#include <nvshmem.h>

#include <cstdio>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <sstream>
#include <string>

#include "../util/logging.h"
#include "nvshmem_waitkernel.h"

__global__ void __launch_bounds__(1)
    wait_until_on_stream_and_reset(uint64_t *wait_flag, uint64_t wait_value,
                                   uint64_t signal_reset) {
  nvshmem_uint64_wait_until(wait_flag, NVSHMEM_CMP_EQ, wait_value);
  *wait_flag = signal_reset;
}

void nvshmem_wait_on_stream(uint64_t *sig_addr, WaitKind wait_kind, cudaStream_t stream) {
  uint64_t wait_value = 1;
  uint64_t signal_reset = 0;
  cudaStream_t cur_stream = stream;

  NVTE_CHECK(wait_kind >= WaitKind::KERNEL_WAIT && wait_kind <= WaitKind::STREAM_WAIT,
             "Invalid wait kind: ", static_cast<int>(wait_kind));

  switch (wait_kind) {
    case WaitKind::KERNEL_WAIT:
      wait_until_on_stream_and_reset<<<1, 1, 0, cur_stream>>>(sig_addr, wait_value, signal_reset);
      break;
    case WaitKind::NVSHMEM_WAIT:
      nvshmemx_uint64_wait_until_on_stream(sig_addr, NVSHMEM_CMP_EQ, wait_value, cur_stream);
      cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
                           CU_STREAM_WRITE_VALUE_DEFAULT);
      break;
    case WaitKind::STREAM_WAIT:
      cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)wait_value,
                          CU_STREAM_WAIT_VALUE_GEQ);
      cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
                           CU_STREAM_WRITE_VALUE_DEFAULT);
      break;
  }
}
transformer_engine/common/nvshmem_api/nvshmem_waitkernel.h  0 → 100644  View file @ ab3e5a92

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H
#define TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H

#ifdef __cplusplus
#include <cstdint>
extern "C" {
#else
#include <stdint.h>
#endif

/*! \enum WaitKind
 *  \brief Types of wait operations that can be performed.
 */
enum class WaitKind {
  KERNEL_WAIT = 0,  /*!< Wait using a CUDA kernel */
  NVSHMEM_WAIT = 1, /*!< Wait using NVSHMEM wait operation */
  STREAM_WAIT = 2   /*!< Wait using CUDA stream synchronization */
};

/*! \brief Wait on a signal until a certain condition is met.
 *
 *  \param[in] sig_addr   The address of the signal to wait on.
 *  \param[in] wait_kind  The kind of wait to perform.
 *  \param[in] stream     The stream to wait on.
 */
void nvshmem_wait_on_stream(uint64_t *sig_addr, WaitKind wait_kind, cudaStream_t stream);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H
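For orientation, a hedged sketch of a consumer using the new nvshmem_wait_on_stream helper on the stream-ordered driver path; the symmetric allocation, zero-initialisation, and the producer-side signalling protocol are ordinary NVSHMEM usage assumed for the example, not part of this commit, and NVSHMEM is assumed to already be initialised.

#include <cuda_runtime.h>
#include <nvshmem.h>

#include "nvshmem_waitkernel.h"

// Minimal sketch: block `stream` until a peer PE writes 1 to a symmetric
// signal word, then let the helper reset it to 0 for reuse.
void consume_on(cudaStream_t stream) {
  uint64_t *sig = static_cast<uint64_t *>(nvshmem_malloc(sizeof(uint64_t)));
  cudaMemset(sig, 0, sizeof(uint64_t));  // start un-signalled

  // STREAM_WAIT uses cuStreamWaitValue64/cuStreamWriteValue64 under the hood,
  // so no kernel launch is needed on the consumer side.
  nvshmem_wait_on_stream(sig, WaitKind::STREAM_WAIT, stream);

  // ... kernels enqueued on `stream` after this point see the signalled data ...

  nvshmem_free(sig);
}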
transformer_engine/common/permutation/permutation.cu  View file @ ab3e5a92
...
@@ -351,7 +351,7 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so
   const transformer_engine::Tensor *input_fwd_cu =
       reinterpret_cast<const transformer_engine::Tensor *>(input_fwd);

-  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       input_cu->data.dtype, T,
       nvte_permute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr),
                             reinterpret_cast<T *>(output_cu->data.dptr),
...
@@ -377,7 +377,7 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id
   const transformer_engine::Tensor *prob_cu =
       reinterpret_cast<const transformer_engine::Tensor *>(prob);

-  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       input_cu->data.dtype, T,
       nvte_unpermute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr),
                               reinterpret_cast<T *>(output_cu->data.dptr),
...
transformer_engine/common/recipe/__init__.py  View file @ ab3e5a92
...
@@ -5,6 +5,7 @@
 """This module provides predefined FP8 recipes."""
 from __future__ import annotations
 import warnings
+import os
 from enum import Enum
 from typing import Literal, Optional, Union, Callable, NamedTuple
 from pydantic.dataclasses import dataclass
...
@@ -81,6 +82,10 @@ class Recipe:
         """Whether the given recipe is per-tensor scaling."""
         return isinstance(self, (DelayedScaling, Float8CurrentScaling))

+    def float8_block_scaling(self):
+        """Whether the given recipe is float8 blockwise scaling."""
+        return isinstance(self, Float8BlockScaling)

 @dataclass()
 class DelayedScaling(Recipe):
...
@@ -287,3 +292,99 @@ class MXFP8BlockScaling(Recipe):
     def __repr__(self) -> str:
         return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]},"

+@dataclass()
+class Float8BlockScaling(Recipe):
+    """
+    Use block-wise scaling for FP8 tensors.
+
+    In this strategy, tensors are scaled in blockwise fashion. Values within
+    each block share a common scaling factor. The block dimensionality
+    can be configured. The scaling factors are float32 containers. They
+    will by default be constrained to powers of 2.
+
+    Since the scaling happens in a particular direction (either rowwise
+    or columnwise), the quantized tensor and its transpose are not numerically
+    equivalent. Due to this, when Transformer Engine needs both the FP8 tensor
+    and its transpose (e.g. to calculate both forward and backward pass),
+    during the quantization both versions are computed from the high precision
+    input to avoid double quantization errors.
+
+    NOTE: To relax the default constraint that scales be powers of 2, set env variable
+    NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it for the recipe defaults.
+    export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1
+    Or initialize the Recipe with non-default QParams in code for increased control.
+
+    Parameters
+    ----------
+    fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3
+        Controls the FP8 data format used during forward and backward pass.
+    fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
+        used for quantization of input tensor x
+    fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
+        used for quantization of weight tensor w
+    fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
+        used for quantization of gradient tensor dY
+    x_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
+        qblock scaling for x.
+    w_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
+        qblock scaling for w.
+    grad_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
+        qblock scaling for grad.
+    fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False
+        used for calculating output y in forward pass
+    fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True
+        use for calculating dgrad in backward pass
+    fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True
+        use for calculating dgrad in backward pass
+    """
+
+    use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1"
+    fp8_format: Format = Format.E4M3
+    fp8_quant_fwd_inp = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0)
+    fp8_quant_fwd_weight = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0)
+    fp8_quant_bwd_grad = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0)
+    x_block_scaling_dim: int = 1
+    w_block_scaling_dim: int = 2
+    grad_block_scaling_dim: int = 1
+    fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=True)
+    fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True)
+    fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True)
+    fp8_dpa: bool = False
+    fp8_mha: bool = False
+
+    def __post_init__(self) -> None:
+        assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x"
+        assert self.w_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for w"
+        assert self.grad_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for grad"
+        assert not (
+            self.x_block_scaling_dim == 2 and self.w_block_scaling_dim == 2
+        ), "2D by 2D block gemm not supported."
+        assert not (
+            self.x_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2
+        ), "2D by 2D block gemm not supported."
+        assert not (
+            self.w_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2
+        ), "2D by 2D block gemm not supported."
+        assert self.fp8_gemm_fprop.use_split_accumulator, "Split accumulator required for fprop."
+        assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad."
+        assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad."
+
+    def __repr__(self) -> str:
+        return (
+            f"format={str(self.fp8_format).split('.')[1]}, "
+            f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
+            f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, "
+            f"fp8_quant_bwd_grad={self.fp8_quant_bwd_grad}, "
+            f"x_block_scaling_dim={self.x_block_scaling_dim}, "
+            f"w_block_scaling_dim={self.w_block_scaling_dim}, "
+            f"grad_block_scaling_dim={self.grad_block_scaling_dim}, "
+            f"fp8_gemm_fprop={self.fp8_gemm_fprop}, "
+            f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, "
+            f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, "
+            f"fp8_dpa={self.fp8_dpa}, "
+            f"fp8_mha={self.fp8_mha}"
+        )
transformer_engine/common/recipe/current_scaling.cu  View file @ ab3e5a92
...
@@ -156,7 +156,8 @@ namespace {
 __global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr,
                                                const float max_fp8, const bool force_pow_2_scales,
                                                const float epsilon) {
-  *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon);
+  *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon,
+                                       std::numeric_limits<float>::max());
 }

 }  // namespace
...
transformer_engine/common/recipe/recipe_common.cuh  View file @ ab3e5a92
...
@@ -7,19 +7,21 @@
 #ifndef TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
 #define TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_

 #include <limits>

 #include "common/common.h"

 namespace transformer_engine {

 __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_fp8,
-                                                          bool force_pow_2_scales, float epsilon) {
+                                                          bool force_pow_2_scales, float epsilon,
+                                                          float value_for_inf) {
+  // NOTE: NAN amax evaluates false for <, handled further down.
   if (amax < epsilon) {
     amax = epsilon;
   }
   float scale = 1.f;

-  if (isinf(amax) || amax == 0.f) {
+  if (isinf(amax) || amax == 0.f || isnan(amax)) {
     return scale;
   }
...
@@ -32,18 +34,13 @@ __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_f
   // the scale is not representable in FP32.
   if (isinf(scale)) {
     // use fp32 max to represent the scale
-    scale = std::numeric_limits<float>::max();
+    scale = value_for_inf;
   }

-  if (isnan(scale)) {
-    scale = 1.f;
-  }

   if (force_pow_2_scales) {
     uint32_t scale_bits = *reinterpret_cast<uint32_t *>(&scale);
     scale_bits &= 0xFF800000;
     // If the exponent was zero, we have a logic error.
-    __builtin_assume(scale_bits != 0);
+    __builtin_assume(scale_bits != 0 || scale == 0.0);
     __builtin_assume(scale_bits != 0x80000000);
     scale = *reinterpret_cast<float *>(&scale_bits);
   }
...
@@ -51,6 +48,26 @@ __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_f
   return scale;
 }

+// Calculate the quantization scale for an individual data element
+// given the amax(abs(tile)) value for a given quantization tile.
+//
+// Arguments:
+// IType: data type of the tensor being quantized (float or bf16)
+// OType: quantized data type (e4m3 or e5m2)
+// amax: The evaluation of amax(abs(tile)) for the quantization tile.
+// eps: An epsilon used as a floor for amax.
+// pow_2_scaling: Whether to force the scale to be a power of 2.
+template <typename IType, typename OType>
+__device__ __forceinline__ float compute_scale_from_types(const float amax, const float eps,
+                                                          const float pow_2_scaling) {
+  constexpr float fp8_max = TypeInfo<OType>::max_finite_value;
+  // NOTE: We're relying on compute_scale_from_amax to have behavior where it
+  // clips the mantissa of the max_finite_value if power of 2 scaling applies.
+  constexpr float value_for_inf = TypeInfo<IType>::max_finite_value;
+  return compute_scale_from_amax(amax, fp8_max, pow_2_scaling, eps, value_for_inf);
+}

 }  // namespace transformer_engine

 #endif  // TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
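To make the power-of-2 rounding step above concrete, here is a small host-side analogue of the exponent-masking; the numbers are illustrative and the helper is not a library function (the real device code additionally handles the epsilon floor, inf/nan amax, and the value_for_inf clamp).

#include <cstdint>
#include <cstdio>
#include <cstring>

// Host-side illustration of forcing a scale to a power of 2 by clearing the
// FP32 mantissa bits, mirroring the scale_bits &= 0xFF800000 step above.
float force_pow_2(float scale) {
  uint32_t bits;
  std::memcpy(&bits, &scale, sizeof(bits));
  bits &= 0xFF800000u;  // keep sign + exponent, drop mantissa (rounds toward zero)
  std::memcpy(&scale, &bits, sizeof(bits));
  return scale;
}

int main() {
  const float amax = 3.0f;       // per-tile abs-max (example value)
  const float fp8_max = 448.0f;  // E4M3 finite max
  const float scale = fp8_max / amax;                      // ~149.33
  std::printf("%g -> %g\n", scale, force_pow_2(scale));    // prints 149.333 -> 128
  return 0;
}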
transformer_engine/common/transformer_engine.cpp  View file @ ab3e5a92
...
@@ -211,53 +211,32 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) {
       reinterpret_cast<const transformer_engine::Tensor *>(tensor)->dtype());
 }

+NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
+  NVTEShape ret;
+  if (ndim == 0) {
+    ret.data = nullptr;
+    ret.ndim = 0;
+    return ret;
+  }
+  NVTE_CHECK(ndim <= sizeof(ret.owned_data) / sizeof(ret.owned_data[0]),
+             "Too many dims for NVTEShape (requested: ", ndim,
+             ", max: ", sizeof(ret.owned_data) / sizeof(ret.owned_data[0]), ")");
+  std::copy(data, data + ndim, ret.owned_data);
+  ret.data = ret.owned_data;
+  ret.ndim = ndim;
+  return ret;
+}

 NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
   if (tensor == nullptr) {
     NVTE_ERROR("Invalid tensor");
   }
-  NVTEShape ret;

   // Determine tensor shape depending on tensor format
   const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
-  switch (t.scaling_mode) {
-    case NVTE_DELAYED_TENSOR_SCALING: {
-      if (!t.has_data() && t.has_columnwise_data()) {
-        // We can infer tensor shape if FP8 tensor only has FP8 data
-        // transpose. However, NVTEShape only contains a pointer and
-        // cannot store temporary data. We hack around this by caching
-        // the tensor shape within the empty FP8 data.
-        auto &shape_cache = const_cast<std::vector<size_t> &>(t.data.shape);
-        shape_cache.clear();
-        if (!t.columnwise_data.shape.empty()) {
-          for (size_t i = 1; i < t.columnwise_data.shape.size(); i++) {
-            shape_cache.push_back(t.columnwise_data.shape[i]);
-          }
-          shape_cache.push_back(t.columnwise_data.shape.front());
-        }
-        ret.data = shape_cache.data();
-        ret.ndim = shape_cache.size();
-      } else {
-        ret.data = t.data.shape.data();
-        ret.ndim = t.data.shape.size();
-      }
-      break;
-    }
-    case NVTE_MXFP8_1D_SCALING: {
-      if (!t.has_data() && t.has_columnwise_data()) {
-        ret.data = t.columnwise_data.shape.data();
-        ret.ndim = t.columnwise_data.shape.size();
-      } else {
-        ret.data = t.data.shape.data();
-        ret.ndim = t.data.shape.size();
-      }
-      break;
-    }
-    default:
-      NVTE_ERROR("Cannot parse tensor shape with scaling mode \"",
-                 transformer_engine::to_string(t.scaling_mode), "\"");
-  }
-  return ret;
+  std::vector<size_t> shape = t.shape();
+  return nvte_make_shape(shape.data(), shape.size());
 }

 NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
...
@@ -265,10 +244,7 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
     NVTE_ERROR("Invalid tensor");
   }
   const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
-  NVTEShape ret;
-  ret.data = t.columnwise_data.shape.data();
-  ret.ndim = t.columnwise_data.shape.size();
-  return ret;
+  return nvte_make_shape(t.columnwise_data.shape.data(), t.columnwise_data.shape.size());
 }

 size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }
...
@@ -292,7 +268,7 @@ size_t nvte_tensor_numel(const NVTETensor tensor) {
 size_t nvte_tensor_element_size(const NVTETensor tensor) {
   if (tensor == nullptr) return sizeof(float);
   const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
-  return transformer_engine::typeToSize(t.data.dtype);
+  return transformer_engine::typeToSize(t.dtype());
 }

 void *nvte_tensor_data(const NVTETensor tensor) {
...
@@ -336,12 +312,11 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
 }

 NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
-  if (tensor == nullptr) return {nullptr, 0};
+  if (tensor == nullptr) {
+    return nvte_make_shape(nullptr, 0);
+  }
   const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
-  NVTEShape ret;
-  ret.data = t.scale_inv.shape.data();
-  ret.ndim = t.scale_inv.shape.size();
-  return ret;
+  return nvte_make_shape(t.scale_inv.shape.data(), t.scale_inv.shape.size());
 }

 void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
...
@@ -463,6 +438,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
     case kNVTEQuantizationConfigAmaxEpsilon:
       std::memcpy(buf, &config_.amax_epsilon, attr_size);
       break;
+    case kNVTEQuantizationConfigNoopTensor:
+      std::memcpy(buf, &config_.noop_tensor, attr_size);
+      break;
     default:
       NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
   }
...
@@ -492,6 +470,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
     case kNVTEQuantizationConfigAmaxEpsilon:
       std::memcpy(&config_.amax_epsilon, buf, attr_size);
       break;
+    case kNVTEQuantizationConfigNoopTensor:
+      std::memcpy(&config_.noop_tensor, buf, attr_size);
+      break;
     default:
       NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
   }
...
transformer_engine/common/transpose/cast_transpose.h  View file @ ab3e5a92
...
@@ -23,6 +23,42 @@ template <typename ComputeType, typename ParamOP, ComputeType (*OP1)(ComputeType
 void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, Tensor *output,
                                cudaStream_t stream);

+void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
+                                         SimpleTensor &scale_inv_t, SimpleTensor &output,
+                                         SimpleTensor &output_t, const float epsilon,
+                                         const bool return_transpose, const bool pow_2_scale,
+                                         cudaStream_t stream);
+
+// enum class for rowwise usage
+enum class FP8BlockwiseRowwiseOption {
+  // No rowwise data
+  NONE,
+  // Rowwise data, scales in GEMM format
+  ROWWISE
+  // TODO: FP8 all gather requires some changes.
+  // 1. Compact scales are better for gathering than the GEMM format.
+};
+
+// enum class for columnwise usage
+// For Hopper sm90 with only TN fp8 gemm, there is need to do columnwise transpose when doing 1D block scaling
+enum class FP8BlockwiseColumnwiseOption {
+  // No columnwise data
+  NONE,
+  // Columnwise data transposed from original shape.
+  // Scales in GEMM format corresponding to GEMM ingesting transposed column data.
+  COLUMNWISE_TRANSPOSE
+  // TODO: FP8 all gather requires some changes.
+  // 1. The transpose gets in the way of the all gather.
+  // 2. Compact scales are better for gathering than the GEMM format.
+};
+
+void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
+                                         SimpleTensor &scale_inv_t, SimpleTensor &output,
+                                         SimpleTensor &output_t, const float epsilon,
+                                         FP8BlockwiseRowwiseOption rowwise_option,
+                                         FP8BlockwiseColumnwiseOption columnwise_option,
+                                         const bool pow_2_scale, cudaStream_t stream);

 }  // namespace transformer_engine::detail

 #endif  // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
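For the square (128x128) blockwise path declared above, each tile contributes one float32 entry to scale_inv, and to scale_inv_t when the transpose is produced. A tiny illustrative sketch of the expected scale-grid geometry follows; the helper and the tensor sizes are made up for the example and are not library code.

#include <cstddef>
#include <cstdio>

// Illustrative helper (not part of the library): number of 128x128 tiles, and
// therefore float32 scale_inv entries, needed for an M x N tensor.
constexpr size_t kTile = 128;

constexpr size_t ceil_div(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
  const size_t M = 512, N = 384;  // example GEMM operand shape
  std::printf("scale_inv   : %zu x %zu tiles\n", ceil_div(M, kTile), ceil_div(N, kTile));  // 4 x 3
  std::printf("scale_inv_t : %zu x %zu tiles\n", ceil_div(N, kTile), ceil_div(M, kTile));  // 3 x 4
  return 0;
}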
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu
0 → 100644
View file @
ab3e5a92
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <cuda/barrier>
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#if (!defined(__CUDA_MINIMUM_ARCH__) && __CUDA_ARCH__ >= 900) || \
(defined(__CUDA_MINIMUM_ARCH__) && __CUDA_MINIMUM_ARCH__ >= 900)
#define TMA_HW_SUPPORTED
#endif
namespace
transformer_engine
{
namespace
{
// const values configuration
constexpr
size_t
kThreadsPerWarp
=
32
;
#ifdef TMA_HW_SUPPORTED
constexpr
size_t
BLOCK_TILE_DIM
=
128
;
constexpr
size_t
WARP_TILE_DIM_X
=
32
;
constexpr
size_t
WARP_TILE_DIM_Y
=
64
;
constexpr
size_t
THREAD_TILE_DIM_X
=
16
;
constexpr
size_t
THREAD_TILE_DIM_Y
=
4
;
#else
constexpr
size_t
BLOCK_TILE_DIM
=
128
;
constexpr
size_t
WARP_TILE_DIM_X
=
64
;
constexpr
size_t
WARP_TILE_DIM_Y
=
32
;
constexpr
size_t
THREAD_TILE_DIM_X
=
8
;
constexpr
size_t
THREAD_TILE_DIM_Y
=
8
;
#endif
#ifdef TMA_HW_SUPPORTED
constexpr
size_t
NUM_BYTES_PER_BANK
=
4
;
constexpr
size_t
NUM_BANKS_PER_SHARED_ELEM
=
THREAD_TILE_DIM_Y
/
NUM_BYTES_PER_BANK
;
constexpr
size_t
SHARED_BLOCK_TILE_DIM_Y
=
BLOCK_TILE_DIM
;
constexpr
size_t
SHARED_BLOCK_TILE_DIM_X_BANKS
=
BLOCK_TILE_DIM
/
(
NUM_BYTES_PER_BANK
*
NUM_BANKS_PER_SHARED_ELEM
);
constexpr
size_t
NUM_BANKS_Y_IN_WARP
=
WARP_TILE_DIM_Y
/
NUM_BYTES_PER_BANK
;
#endif
constexpr
size_t
ELE_PER_THREAD
=
THREAD_TILE_DIM_X
*
THREAD_TILE_DIM_Y
;
constexpr
size_t
THREADS_PER_BLOCK
=
BLOCK_TILE_DIM
*
BLOCK_TILE_DIM
/
ELE_PER_THREAD
;
constexpr
size_t
NUM_WARPS_X_IN_BLOCK
=
BLOCK_TILE_DIM
/
WARP_TILE_DIM_X
;
constexpr
size_t
NUM_WARPS_Y_IN_BLOCK
=
BLOCK_TILE_DIM
/
WARP_TILE_DIM_Y
;
constexpr
size_t
NUM_WARPS_IN_BLOCK
=
NUM_WARPS_X_IN_BLOCK
*
NUM_WARPS_Y_IN_BLOCK
;
constexpr
size_t
NUM_THREADS_X_IN_WARP
=
WARP_TILE_DIM_X
/
THREAD_TILE_DIM_X
;
constexpr
size_t
NUM_THREADS_Y_IN_WARP
=
kThreadsPerWarp
/
NUM_THREADS_X_IN_WARP
;
#define MIN(a, b) (a < b ? a : b)
template
<
bool
kReturnTranspose
,
typename
CType
,
typename
IType
,
typename
OType
>
__global__
void
__launch_bounds__
(
THREADS_PER_BLOCK
)
block_scaled_cast_transpose_kernel
(
const
IType
*
const
input
,
OType
*
const
output_c
,
OType
*
const
output_t
,
CType
*
const
tile_scales_inv_c
,
CType
*
const
tile_scales_inv_t
,
const
size_t
row_length
,
const
size_t
num_rows
,
const
size_t
scale_stride_x
,
const
size_t
scale_stride_y
,
const
size_t
scale_t_stride_x
,
const
size_t
scale_t_stride_y
,
const
float
epsilon
,
const
__grid_constant__
CUtensorMap
tensor_map_output_t
,
bool
pow_2_scaling
)
{
using
IVec
=
Vec
<
IType
,
THREAD_TILE_DIM_X
>
;
using
OVecCast
=
Vec
<
OType
,
THREAD_TILE_DIM_X
>
;
using
OVecTrans
=
Vec
<
OType
,
THREAD_TILE_DIM_Y
>
;
// shared mem for amax reduction in entire block, each warp produces one amax, there are
// NUM_WARPS_IN_BLOCK amax to reduce
__shared__
CType
block_tile_amax_shared
[
NUM_WARPS_IN_BLOCK
];
IVec
thrd_tile_input
[
THREAD_TILE_DIM_Y
];
constexpr
int
THREAD_TILE_DIM_X_
=
kReturnTranspose
?
THREAD_TILE_DIM_X
:
1
;
OVecTrans
thrd_tile_out_trans
[
THREAD_TILE_DIM_X_
];
const
int
tid_in_warp
=
threadIdx
.
x
%
kThreadsPerWarp
;
const
int
tid_in_warp_x
=
tid_in_warp
%
NUM_THREADS_X_IN_WARP
;
const
int
tid_in_warp_y
=
tid_in_warp
/
NUM_THREADS_X_IN_WARP
;
const
int
warp_id_in_block
=
threadIdx
.
x
/
kThreadsPerWarp
;
const
int
warp_id_in_block_x
=
warp_id_in_block
%
NUM_WARPS_X_IN_BLOCK
;
const
int
warp_id_in_block_y
=
warp_id_in_block
/
NUM_WARPS_X_IN_BLOCK
;
// This is ONLY true if the input is a full tile
const
int
tile_id_x
=
blockIdx
.
x
;
const
int
tile_id_y
=
blockIdx
.
y
;
const
size_t
block_tile_start_idx
=
tile_id_y
*
BLOCK_TILE_DIM
*
row_length
+
tile_id_x
*
BLOCK_TILE_DIM
;
const
size_t
warp_tile_start_idx
=
block_tile_start_idx
+
warp_id_in_block_y
*
THREAD_TILE_DIM_Y
*
NUM_THREADS_Y_IN_WARP
*
row_length
+
warp_id_in_block_x
*
THREAD_TILE_DIM_X
*
NUM_THREADS_X_IN_WARP
;
const
size_t
thread_tile_start_idx
=
warp_tile_start_idx
+
tid_in_warp_y
*
THREAD_TILE_DIM_Y
*
row_length
+
tid_in_warp_x
*
THREAD_TILE_DIM_X
;
CType
warp_tile_amax
;
CType
block_tile_amax
;
CType
block_tile_scale
;
CType
amax
=
0
;
// Step 1: Load a block tile of input data into thread tiles on registers
#pragma unroll
for
(
int
i
=
0
;
i
<
THREAD_TILE_DIM_Y
;
i
++
)
{
thrd_tile_input
[
i
].
load_from
(
input
+
thread_tile_start_idx
+
i
*
row_length
);
}
// Step 2: calculate block tile amax and scale
// Calculate thread_tile amax
for
(
int
i
=
0
;
i
<
THREAD_TILE_DIM_Y
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
THREAD_TILE_DIM_X
;
j
++
)
{
__builtin_assume
(
amax
>=
0
);
amax
=
fmaxf
(
amax
,
fabsf
(
static_cast
<
CType
>
(
thrd_tile_input
[
i
].
data
.
elt
[
j
])));
}
}
// Reduce amax in the warp (32x32 tile)
warp_tile_amax
=
warp_reduce_max
<
kThreadsPerWarp
>
(
amax
);
// broadcast the amax to all threads in a warp from the lane 0
constexpr
int
lane_zero
=
0
;
warp_tile_amax
=
__shfl_sync
(
0xFFFFFFFF
,
warp_tile_amax
,
lane_zero
);
// reduce warp_tile_amax across multiple warps in a thread block using shared mem
if
(
tid_in_warp
==
0
)
{
block_tile_amax_shared
[
warp_id_in_block_y
*
NUM_WARPS_X_IN_BLOCK
+
warp_id_in_block_x
]
=
warp_tile_amax
;
}
__syncthreads
();
// only 8 elements needs reduction, if using reduction tree, multiple _syncthreads will be needed,
// instead we just let thread 0 do the job
if
(
threadIdx
.
x
==
0
)
{
CType
blk_amax
=
block_tile_amax_shared
[
0
];
#pragma unroll
for
(
int
idx
=
1
;
idx
<
NUM_WARPS_IN_BLOCK
;
idx
++
)
{
blk_amax
=
fmaxf
(
blk_amax
,
block_tile_amax_shared
[
idx
]);
}
block_tile_amax_shared
[
0
]
=
blk_amax
;
}
__syncthreads
();
block_tile_amax
=
block_tile_amax_shared
[
0
];
block_tile_scale
=
compute_scale_from_types
<
IType
,
OType
>
(
block_tile_amax
,
epsilon
,
pow_2_scaling
);
if
(
threadIdx
.
x
==
0
)
{
static_assert
(
std
::
is_same
<
CType
,
float
>::
value
);
const
CType
scale_inv
=
1.0
f
/
block_tile_scale
;
size_t
row_idx
=
tile_id_y
;
size_t
col_idx
=
tile_id_x
;
tile_scales_inv_c
[
row_idx
*
scale_stride_y
+
col_idx
*
scale_stride_x
]
=
scale_inv
;
if
constexpr
(
kReturnTranspose
)
{
row_idx
=
tile_id_x
;
col_idx
=
tile_id_y
;
tile_scales_inv_t
[
row_idx
*
scale_t_stride_y
+
col_idx
*
scale_t_stride_x
]
=
scale_inv
;
}
}
// Step 3: Store cast output, Step 4: do transpose within thread tile
OVecCast
tmp_output_c
;
for
(
int
i
=
0
;
i
<
THREAD_TILE_DIM_Y
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
THREAD_TILE_DIM_X
;
j
++
)
{
// Step 3: Store cast output
CType
scale_data
=
block_tile_scale
;
OType
scaled_elt
=
static_cast
<
OType
>
(
static_cast
<
CType
>
(
thrd_tile_input
[
i
].
data
.
elt
[
j
])
*
scale_data
);
tmp_output_c
.
data
.
elt
[
j
]
=
scaled_elt
;
// Step 4: do transpose within thread tile
if
constexpr
(
kReturnTranspose
)
{
thrd_tile_out_trans
[
j
].
data
.
elt
[
i
]
=
scaled_elt
;
}
}
tmp_output_c
.
store_to
(
output_c
+
thread_tile_start_idx
+
i
*
row_length
);
}
// Step 4: store transpose into shared memory
if
constexpr
(
kReturnTranspose
)
{
#ifdef TMA_HW_SUPPORTED
__shared__
alignas
(
128
)
OVecTrans
block_tile_trans_shared
[
SHARED_BLOCK_TILE_DIM_Y
][
SHARED_BLOCK_TILE_DIM_X_BANKS
];
OType
(
*
block_tile_trans_shared_otype_ptr
)[
BLOCK_TILE_DIM
]
=
reinterpret_cast
<
OType
(
*
)[
BLOCK_TILE_DIM
]
>
(
block_tile_trans_shared
);
#pragma unroll
for
(
int
i
=
0
;
i
<
THREAD_TILE_DIM_X
;
i
++
)
{
auto
warp_id_in_block_x_
=
warp_id_in_block_y
;
auto
warp_id_in_block_y_
=
warp_id_in_block_x
;
int
row_idx
=
warp_id_in_block_y_
*
THREAD_TILE_DIM_X
*
NUM_THREADS_X_IN_WARP
+
tid_in_warp_x
*
THREAD_TILE_DIM_X
+
i
;
int
col_idx
=
warp_id_in_block_x_
*
(
NUM_BANKS_Y_IN_WARP
/
NUM_BANKS_PER_SHARED_ELEM
)
+
tid_in_warp_y
;
block_tile_trans_shared
[
row_idx
][
col_idx
]
=
thrd_tile_out_trans
[
i
];
}
// Wait for shared memory writes to be visible to TMA engine.
ptx
::
fence_proxy_async_shared_cta
();
__syncthreads
();
// After syncthreads, writes by all threads are visible to TMA engine.
// Step 5: store transpose output
// Initiate TMA transfer to copy shared memory to global memory
if
(
threadIdx
.
x
==
0
)
{
ptx
::
cp_async_bulk_tensor_2d_shared_to_global
(
reinterpret_cast
<
const
uint64_t
*>
(
&
tensor_map_output_t
),
tile_id_y
*
BLOCK_TILE_DIM
,
tile_id_x
*
BLOCK_TILE_DIM
,
reinterpret_cast
<
uint64_t
*>
(
block_tile_trans_shared_otype_ptr
));
// Wait for TMA transfer to have finished reading shared memory.
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx
::
cp_async_bulk_commit_group
();
// Wait for the group to have completed reading from shared memory.
ptx
::
cp_async_bulk_wait_group_read
<
0
>
();
}
#else
// Step 4 Alternative (when TMA is not available, skip writing to shared memory)
const
size_t
block_tile_t_start_idx
=
tile_id_x
*
BLOCK_TILE_DIM
*
num_rows
+
tile_id_y
*
BLOCK_TILE_DIM
;
const
size_t
warp_tile_t_start_idx
=
block_tile_t_start_idx
+
warp_id_in_block_x
*
THREAD_TILE_DIM_X
*
NUM_THREADS_X_IN_WARP
*
num_rows
+
warp_id_in_block_y
*
THREAD_TILE_DIM_Y
*
NUM_THREADS_Y_IN_WARP
;
const
size_t
thread_tile_t_start_idx
=
warp_tile_t_start_idx
+
tid_in_warp_x
*
THREAD_TILE_DIM_X
*
num_rows
+
tid_in_warp_y
*
THREAD_TILE_DIM_Y
;
#pragma unroll
for
(
int
i
=
0
;
i
<
THREAD_TILE_DIM_X
;
i
++
)
{
thrd_tile_out_trans
[
i
].
store_to
(
output_t
+
thread_tile_t_start_idx
+
i
*
num_rows
);
}
#endif
}
}
template
<
bool
kReturnTranspose
,
typename
CType
,
typename
IType
,
typename
OType
>
__global__
void
__launch_bounds__
(
THREADS_PER_BLOCK
)
block_scaled_cast_transpose_kernel_notaligned
(
const
IType
*
const
input
,
OType
*
const
output_c
,
OType
*
const
output_t
,
CType
*
const
tile_scales_inv_c
,
CType
*
const
tile_scales_inv_t
,
const
size_t
row_length
,
const
size_t
num_rows
,
const
size_t
scale_stride_x
,
const
size_t
scale_stride_y
,
const
size_t
scale_t_stride_x
,
const
size_t
scale_t_stride_y
,
const
float
epsilon
,
bool
pow_2_scaling
)
{
using
IVec
=
Vec
<
IType
,
THREAD_TILE_DIM_X
>
;
using
OVecCast
=
Vec
<
OType
,
THREAD_TILE_DIM_X
>
;
using
OVecTrans
=
Vec
<
OType
,
THREAD_TILE_DIM_Y
>
;
// shared mem for amax reduction in entire block, each warp produces one amax, there are
// NUM_WARPS_IN_BLOCK amax to reduce
__shared__
CType
block_tile_amax_shared
[
NUM_WARPS_IN_BLOCK
];
IVec
thrd_tile_input
[
THREAD_TILE_DIM_Y
];
constexpr
int
THREAD_TILE_DIM_X_
=
kReturnTranspose
?
THREAD_TILE_DIM_X
:
1
;
OVecTrans
thrd_tile_out_trans
[
THREAD_TILE_DIM_X_
];
const
int
tid_in_warp
=
threadIdx
.
x
%
kThreadsPerWarp
;
const
int
tid_in_warp_x
=
tid_in_warp
%
NUM_THREADS_X_IN_WARP
;
const
int
tid_in_warp_y
=
tid_in_warp
/
NUM_THREADS_X_IN_WARP
;
const
int
warp_id_in_block
=
threadIdx
.
x
/
kThreadsPerWarp
;
const
int
warp_id_in_block_x
=
warp_id_in_block
%
NUM_WARPS_X_IN_BLOCK
;
const
int
warp_id_in_block_y
=
warp_id_in_block
/
NUM_WARPS_X_IN_BLOCK
;
const
int
tile_id_x
=
blockIdx
.
x
;
const
int
tile_id_y
=
blockIdx
.
y
;
const
size_t
block_tile_start_row_idx
=
tile_id_y
*
BLOCK_TILE_DIM
;
const
size_t
block_tile_start_col_idx
=
tile_id_x
*
BLOCK_TILE_DIM
;
const
size_t
block_tile_start_idx
=
block_tile_start_row_idx
*
row_length
+
block_tile_start_col_idx
;
const
size_t
warp_tile_start_idx
=
block_tile_start_idx
+
warp_id_in_block_y
*
THREAD_TILE_DIM_Y
*
NUM_THREADS_Y_IN_WARP
*
row_length
+
warp_id_in_block_x
*
THREAD_TILE_DIM_X
*
NUM_THREADS_X_IN_WARP
;
const
size_t
thread_tile_start_idx
=
warp_tile_start_idx
+
tid_in_warp_y
*
THREAD_TILE_DIM_Y
*
row_length
+
tid_in_warp_x
*
THREAD_TILE_DIM_X
;
// handle non-full tile
// check for three cases: full thread tile, nonfull thread tile, empty thread tile
// for empty thread tile, directly write zero to the transposed shared mem buffer
// for nonfull thread tile, fill zero to thread tile and act as if it's full
const
size_t
thread_tile_start_row_idx
=
tile_id_y
*
BLOCK_TILE_DIM
+
warp_id_in_block_y
*
THREAD_TILE_DIM_Y
*
NUM_THREADS_Y_IN_WARP
+
tid_in_warp_y
*
THREAD_TILE_DIM_Y
;
const
size_t
thread_tile_start_col_idx
=
tile_id_x
*
BLOCK_TILE_DIM
+
warp_id_in_block_x
*
THREAD_TILE_DIM_X
*
NUM_THREADS_X_IN_WARP
+
tid_in_warp_x
*
THREAD_TILE_DIM_X
;
const
size_t
thread_tile_end_row_idx
=
thread_tile_start_row_idx
+
THREAD_TILE_DIM_Y
-
1
;
const
size_t
thread_tile_end_col_idx
=
thread_tile_start_col_idx
+
THREAD_TILE_DIM_X
-
1
;
bool
full_thrd_tile
=
(
thread_tile_end_row_idx
<
num_rows
)
&&
(
thread_tile_end_col_idx
<
row_length
);
bool
empty_thrd_tile
=
(
thread_tile_start_row_idx
>=
num_rows
)
||
(
thread_tile_start_col_idx
>=
row_length
);
bool
nonfull_thrd_tile
=
(
!
full_thrd_tile
)
&&
(
!
empty_thrd_tile
);
const
size_t
      thread_tile_ncols =
      MIN(THREAD_TILE_DIM_X,
          (MIN(thread_tile_end_col_idx, row_length - 1) - thread_tile_start_col_idx + 1));
  const size_t thread_tile_nrows =
      MIN(THREAD_TILE_DIM_Y,
          (MIN(thread_tile_end_row_idx, num_rows - 1) - thread_tile_start_row_idx + 1));

  CType warp_tile_amax;
  CType block_tile_amax;
  CType block_tile_scale;
  CType amax = 0;

  if (!empty_thrd_tile) {
    // Step 1: Load a block tile of input data into thread tiles on registers
    // Edge case: nonfull thread tile case, will use the partial load function here
    if (nonfull_thrd_tile) {
#pragma unroll
      for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
        if (i >= thread_tile_nrows) {
          thrd_tile_input[i].clear();
        } else {
          thrd_tile_input[i].load_from_elts(input + thread_tile_start_idx + i * row_length, 0,
                                            thread_tile_ncols);
        }
      }
    } else {
#pragma unroll
      for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
        thrd_tile_input[i].load_from_elts(input + thread_tile_start_idx + i * row_length, 0,
                                          THREAD_TILE_DIM_X);
      }
    }

    // Step 2: calculate block tile amax and scale
    // Calculate thread_tile amax
    for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
#pragma unroll
      for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
        __builtin_assume(amax >= 0);
        amax = fmaxf(amax, fabsf(static_cast<CType>(thrd_tile_input[i].data.elt[j])));
      }
    }
  }

  // Reduce amax in the warp (32x32 tile)
  warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax);
  // broadcast the amax to all threads in a warp from the lane 0
  constexpr int lane_zero = 0;
  warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);

  // reduce warp_tile_amax across multiple warps in a thread block using shared mem
  if (tid_in_warp == 0) {
    block_tile_amax_shared[warp_id_in_block_y * NUM_WARPS_X_IN_BLOCK + warp_id_in_block_x] =
        warp_tile_amax;
  }
  __syncthreads();
  // only 8 elements need reduction; a reduction tree would require multiple __syncthreads,
  // so instead we just let thread 0 do the job
  if (threadIdx.x == 0) {
    CType blk_amax = block_tile_amax_shared[0];
#pragma unroll
    for (int idx = 1; idx < NUM_WARPS_IN_BLOCK; idx++) {
      blk_amax = fmaxf(blk_amax, block_tile_amax_shared[idx]);
    }
    block_tile_amax_shared[0] = blk_amax;
  }
  __syncthreads();
  block_tile_amax = block_tile_amax_shared[0];

  block_tile_scale =
      compute_scale_from_types<IType, OType>(block_tile_amax, epsilon, pow_2_scaling);

  if (threadIdx.x == 0) {
    static_assert(std::is_same<CType, float>::value);
    const CType scale_inv = 1.0f / block_tile_scale;
    size_t row_idx = tile_id_y;
    size_t col_idx = tile_id_x;
    tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
    if constexpr (kReturnTranspose) {
      row_idx = tile_id_x;
      col_idx = tile_id_y;
      tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
    }
  }

  // Step 3: Store cast output, Step 4: do transpose within thread tile
  // Edge case: in the non-full tile case, there are three subcases
  //   - for a full thread tile, it's the same as the aligned path
  //   - for a nonfull thread tile, pay attention when saving tmp_output_c to global memory:
  //     cannot use the vector store_to, need the element-wise store_to_elts
  //   - for an empty tile, it should not enter this step; skip to Step 4 and
  //     set thrd_tile_out_trans to all zero
  if constexpr (kReturnTranspose) {
#pragma unroll
    for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
      thrd_tile_out_trans[j].clear();
    }
  }

  if (!empty_thrd_tile) {
    OVecCast tmp_output_c;

    for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
      if (i >= thread_tile_nrows) {
        continue;
      }
#pragma unroll
      for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
        // Step 3: Store cast output
        CType scale_data = block_tile_scale;
        OType scaled_elt =
            static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
        tmp_output_c.data.elt[j] = scaled_elt;
        // Step 4: do transpose within thread tile
        if constexpr (kReturnTranspose) {
          thrd_tile_out_trans[j].data.elt[i] = scaled_elt;
        }
      }
      tmp_output_c.store_to_elts(output_c + thread_tile_start_idx + i * row_length, 0,
                                 thread_tile_ncols);
    }

    if constexpr (kReturnTranspose) {
      const size_t block_tile_t_start_idx =
          tile_id_x * BLOCK_TILE_DIM * num_rows + tile_id_y * BLOCK_TILE_DIM;
      const size_t warp_tile_t_start_idx =
          block_tile_t_start_idx +
          warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP * num_rows +
          warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP;
      const size_t thread_tile_t_start_idx = warp_tile_t_start_idx +
                                             tid_in_warp_x * THREAD_TILE_DIM_X * num_rows +
                                             tid_in_warp_y * THREAD_TILE_DIM_Y;
#pragma unroll
      for (int i = 0; i < thread_tile_ncols; i++) {
        thrd_tile_out_trans[i].store_to_elts(output_t + thread_tile_t_start_idx + i * num_rows, 0,
                                             thread_tile_nrows);
      }
    }
  }
}
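// Note: warp_reduce_max<kThreadsPerWarp>, used above, is a helper provided elsewhere in the
// common utilities. For readers following the amax logic, the sketch below illustrates the
// shuffle-based pattern such a warp max reduction typically follows (assuming a full 32-lane
// warp). It is an illustration only, not the library's actual helper.
__device__ inline float warp_max_reduction_sketch(float val) {
#pragma unroll
  for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
    val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, delta));
  }
  return val;  // lane 0 holds the warp max; the caller broadcasts it with __shfl_sync
}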
template <typename OutputType>
CUtensorMap get_tensor_map(const SimpleTensor &tensor, size_t global_dim_x, size_t global_dim_y) {
  CUtensorMapDataType dataType;
  if constexpr (std::is_same_v<OutputType, __nv_fp8_e4m3> ||
                std::is_same_v<OutputType, __nv_fp8_e5m2>) {
    dataType = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8;
  } else {
    NVTE_CHECK(false, "Invalid Output type (must be FP8).");
  }

  CUtensorMap tensor_map_output_trans{};
  create_2D_tensor_map(tensor_map_output_trans, tensor, global_dim_y, global_dim_x,
                       /*shmemY=*/BLOCK_TILE_DIM, /*shmemX=*/BLOCK_TILE_DIM,
                       /*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType));
  return tensor_map_output_trans;
}

}  // namespace
}  // namespace transformer_engine
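// For reference: compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling), called by
// the kernels above, comes from the recipe helpers (common/recipe/recipe_common.cuh) and turns a
// block amax into the quantization scale. Conceptually it behaves like the sketch below; the
// constants and edge-case handling here are illustrative assumptions, not the exact implementation:
//   amax      = max(amax, epsilon)
//   scale     = fp8_max / amax                                  // fp8_max: 448 for E4M3, 57344 for E5M2
//   if (pow_2_scaling) scale = exp2f(floorf(log2f(scale)))      // round down to a power of two
//   scale_inv = 1 / scale                                       // what is stored in tile_scales_inv_c/_t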
namespace transformer_engine::detail {

void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
                                         SimpleTensor &scale_inv_t, SimpleTensor &output,
                                         SimpleTensor &output_t, const float epsilon,
                                         const bool return_transpose, const bool pow_2_scale,
                                         cudaStream_t stream) {
  NVTE_API_CALL(quantize_transpose_square_blockwise);
  NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
  const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
  size_t num_rows = 1;
  for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
    num_rows *= input.shape.at(i);
  }

  NVTE_CHECK(scale_inv.shape.size() == 2, "scale_inv must have 2 dimensions.");
  size_t scale_k = scale_inv.shape[1];
  const size_t scale_stride_x = 1;
  const size_t scale_stride_y = scale_k;

  size_t scale_t_stride_x = 0;
  size_t scale_t_stride_y = 0;

  if (return_transpose) {
    NVTE_CHECK(output_t.shape.size() == input.shape.size(),
               "output_t must have same number of dimensions as input.");
    if (output_t.shape.size() > 0) {
      NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t.");
      for (size_t i = 1; i < output_t.shape.size(); ++i) {
        NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t");
      }
    }
    NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same type.");
    NVTE_CHECK(scale_inv_t.shape.size() == 2, "scale_inv_t must have 2 dimensions.");
    scale_t_stride_x = 1;
    scale_t_stride_y = scale_inv_t.shape[1];
  }

  const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM);
  const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM);

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.dtype, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
          output.dtype, OutputType,
          TRANSFORMER_ENGINE_SWITCH_CONDITION(
              return_transpose, kReturnTranspose,

              dim3 grid(num_blocks_x, num_blocks_y, 1);
              const bool full_tile =
                  row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0;

              if (full_tile) {
                CUtensorMap tensor_map_output_trans;
                if (return_transpose) {
                  tensor_map_output_trans =
                      get_tensor_map<OutputType>(output_t, num_rows, row_length);
                }
                block_scaled_cast_transpose_kernel<kReturnTranspose, float, InputType, OutputType>
                    <<<grid, THREADS_PER_BLOCK, 0, stream>>>(
                        reinterpret_cast<const InputType *>(input.dptr),
                        reinterpret_cast<OutputType *>(output.dptr),
                        reinterpret_cast<OutputType *>(output_t.dptr),
                        reinterpret_cast<float *>(scale_inv.dptr),
                        reinterpret_cast<float *>(scale_inv_t.dptr), row_length, num_rows,
                        scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
                        epsilon, tensor_map_output_trans, pow_2_scale);
              } else {
                block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType,
                                                              OutputType>
                    <<<grid, THREADS_PER_BLOCK, 0, stream>>>(
                        reinterpret_cast<const InputType *>(input.dptr),
                        reinterpret_cast<OutputType *>(output.dptr),
                        reinterpret_cast<OutputType *>(output_t.dptr),
                        reinterpret_cast<float *>(scale_inv.dptr),
                        reinterpret_cast<float *>(scale_inv_t.dptr), row_length, num_rows,
                        scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
                        epsilon, pow_2_scale);
              }  // full-tile
              )  // return_transpose
          )      // OutputType
      )          // InputType

  NVTE_CHECK_CUDA(cudaGetLastError());
}

}  // namespace transformer_engine::detail
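To make the launch geometry above concrete, the worked numbers below assume BLOCK_TILE_DIM is 128 (the usual 128x128 block-scaling configuration) and a hypothetical 2048 x 4096 input; they illustrate how quantize_transpose_square_blockwise tiles the problem rather than prescribe any particular shapes.

// Illustrative only (BLOCK_TILE_DIM assumed to be 128, input assumed to be 2048 x 4096):
//   num_blocks_x = DIVUP(4096, 128) = 32
//   num_blocks_y = DIVUP(2048, 128) = 16
//   grid         = dim3(32, 16, 1), i.e. one thread block quantizes one 128x128 tile
//   scale_inv    : one float per tile, indexed [tile_id_y][tile_id_x] (here 16 x 32)
//   scale_inv_t  : the transposed layout, indexed [tile_id_x][tile_id_y], written only when
//                  return_transpose is true
// Both extents are multiples of 128, so the full-tile kernel (which consumes the CUtensorMap
// built by get_tensor_map) is selected; otherwise the _notaligned fallback handles ragged edges.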
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu
0 → 100644
View file @ ab3e5a92
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cfloat>
#include <cuda/barrier>
#include <utility>
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/utils.cuh"
namespace
transformer_engine
{
namespace
{
using
transformer_engine
::
detail
::
FP8BlockwiseColumnwiseOption
;
using
transformer_engine
::
detail
::
FP8BlockwiseRowwiseOption
;
// clang-format off
/*
Step 1: Load input to shared memory
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 8 times
* What each thread does in each loop:
* 8 elements are read from the input at a time
* 2 elements are written to the shared memory at a time, for a total of 4 times
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 1 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 7 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 8 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 2: Cast and store to output_c
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 4 times
* What each thread does in each loop:
* 2 elements are read from the shared memory at a time, for a total of 8 times
* Every 8 consecutive threads do reduction and calculate the amax of each row
* 16 elements are quantized and written to output_c at a time
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 |
| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 |
| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 1 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 7 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 4 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 3: Transpose, cast and store to output_t
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 2 times
* What each thread does in each loop:
* 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times
* Every 8 consecutive threads do reduction and calculate the amax of each column
* 16 elements are quantized and written to output_t at a time, for a total of 2 times
+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+
| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | |
| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | |
| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | |
| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 |
| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | |
| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | |
| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | |
| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | |
+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+
*/
// clang-format on
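// Worked example of the Step 1 mapping in the diagram above, using the default constants
// defined immediately below (shown for orientation only):
//   kNumThreadsLoad = kTileDim / kNVecIn  = 128 / 8  = 16 threads cover one 128-element row per LDG,
//   r_stride        = kThreadsPerBlock / kNumThreadsLoad = 256 / 16 = 16 rows per iteration,
//   num_iterations  = kTileDim / r_stride = 128 / 16 = 8,
// which matches the "Loop 8 times" annotation: 8 iterations of 16 rows load the full 128x128 tile.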
constexpr size_t kThreadsPerWarp = 32;

// Hyperparameters for performance tuning
constexpr int kTileDim = 128;          // Fixed to 128 because we are using 1x128 and 128x1 quantization
constexpr int kNVecIn = 8;             // The number of elements each LDG touches
constexpr int kNVecOut = 16;           // The number of elements each STG touches
constexpr int kNVecSMem = 2;           // The number of elements each LDS/STS touches
constexpr int kThreadsPerBlock = 256;  // Thread block size, 8 warps in total

// Auto-calculated constants (do not modify directly)
static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem");
static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem");
constexpr int kSMemRow = kTileDim;
constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1;
constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem;
constexpr int kNumThreadsLoad = kTileDim / kNVecIn;
constexpr int kNumThreadsStore = kTileDim / kNVecOut;
static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");
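// Worked example of the shared-memory footprint implied by the constants above (arithmetic only;
// the +1 column is padding, a common trick to avoid shared-memory bank conflicts):
//   kSMemCol  = 128 / 2 + 1 = 65
//   kSMemSize = 128 * 65 * 2 = 16640 elements
//   bf16/fp16 input: 16640 * 2 B = 33,280 B  (fits in the 48 KB default carveout)
//   fp32 input:      16640 * 4 B = 66,560 B  (needs the cudaFuncSetAttribute opt-in done by the
//                                             host launcher at the bottom of this file)
static_assert(kSMemSize == 16640, "The worked example above assumes the default tuning constants");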
template <bool kAligned, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
    const IType* const input, OType* const output_c, OType* const output_t,
    CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
    const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
    const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
    FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
    const bool pow_2_scaling) {
  bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE;
  bool return_columnwise_transpose =
      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE;

  using SMemVec = Vec<IType, kNVecSMem>;
  using OVec = Vec<OType, kNVecOut>;
  union IVec {
    Vec<IType, kNVecIn> input_type;
    Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
  };

  extern __shared__ char smem_base[];
  SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]);

  // Step 1: Load input to shared memory
  {
    constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad;  // stride in rows of shared memory
    constexpr int num_iterations = kTileDim / r_stride;
    const int c_s = (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem);  // Column in shared memory
    int r_s = threadIdx.x / kNumThreadsLoad;                                  // Row in shared memory
    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Column in global memory
    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;                    // Row in global memory
    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;               // Stride in global memory
    const size_t num_ele =
        c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g) : 0;  // For not aligned case
    const IType* input_g = &input[r_g * row_length + c_g];  // Input address in global memory

#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      IVec input_vec;
      // Step 1.1: Load from global memory (input) to registers
      if constexpr (kAligned) {
        input_vec.input_type.load_from(input_g);
      } else {
        if (r_g < num_rows) {
          input_vec.input_type.load_from_elts(input_g, 0, num_ele);
        } else {
          input_vec.input_type.clear();
        }
      }
      // Step 1.2: Write to shared memory
#pragma unroll
      for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
        int c = c_s + i;
        int r = r_s;
        smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
      }
      // Step 1.3: Update input address, row index of shared memory (and row index of global memory
      // for not aligned case)
      input_g += stride_g;
      r_s += r_stride;
      if constexpr (!kAligned) {
        r_g += r_stride;
      }
    }
  }

  __syncthreads();
  // Step 2: Cast and store to output_c
  if (return_rowwise) {
    constexpr int r_stride = kThreadsPerBlock / kNumThreadsStore;  // stride in rows of shared memory
    constexpr int num_iterations = kTileDim / r_stride;
    const int c_s = (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem);  // Column in shared memory
    int r_s = threadIdx.x / kNumThreadsStore;                                   // Row in shared memory
    const size_t c_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Column in global memory
    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;                    // Row in global memory
    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;               // Stride in global memory
    const size_t num_ele =
        c_g < row_length ? min(static_cast<size_t>(kNVecOut), row_length - c_g) : 0;  // For not aligned case
    OType* output_g = &output_c[r_g * row_length + c_g];  // Output address in global memory
    // Each group of kNumThreadsStore threads within a warp processes one row; we need to find the
    // lane id of the first thread to do the reduction.
    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
    // This mask represents which threads should do the reduction together.
    const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
    const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;

#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      SMemVec smem_vec[kNVecOut / kNVecSMem];
      // Step 2.1: Load from shared memory to registers
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
        int c = c_s + i;
        int r = r_s;
        smem_vec[i] = smem[r * kSMemCol + c];
      }
      // Step 2.2: Compute local amax
      CType amax = 0;
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
        for (int j = 0; j < kNVecSMem; ++j) {
          __builtin_assume(amax >= 0);
          amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
        }
      }
      // Step 2.3: Reduce amax
#pragma unroll
      for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
        const float other_amax = __shfl_down_sync(mask, amax, delta);
        __builtin_assume(amax >= 0);
        __builtin_assume(other_amax >= 0);
        amax = fmaxf(amax, other_amax);
      }
      amax = __shfl_sync(mask, amax, src_lane);
      CType scale;
      // Step 2.4: Compute scale
      scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
      // Step 2.5: Write scale_inv
      bool write_scale_inv = is_src_lane;
      if constexpr (!kAligned) {
        write_scale_inv &= (r_g < num_rows);
      }
      if (write_scale_inv) {
        CType scale_inv = 1.0 / scale;
        size_t row_idx = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;
        size_t col_idx = static_cast<size_t>(blockIdx.x);
        tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
      }
      // Step 2.6: Quantize
      OVec output_vec;
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
        for (int j = 0; j < kNVecSMem; ++j) {
          output_vec.data.elt[i * kNVecSMem + j] =
              static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
        }
      }
      // Step 2.7: Store output_c
      if constexpr (kAligned) {
        output_vec.store_to(output_g);
      } else {
        if (r_g < num_rows) {
          output_vec.store_to_elts(output_g, 0, num_ele);
        }
      }
      // Step 2.8: Update output address, row index of shared memory (and row index of global
      // memory for not aligned case)
      output_g += stride_g;
      r_s += r_stride;
      if constexpr (!kAligned) {
        r_g += r_stride;
      }
    }
  }
  // Step 3: Transpose, cast and store to output_t
  if (return_columnwise_transpose) {
    constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore;  // Stride in columns of shared memory
    constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
    const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut;  // Row in shared memory
    int c_s = threadIdx.x / kNumThreadsStore;                     // Column in shared memory
    size_t r_g = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Row in global memory
    const size_t c_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;        // Column in global memory
    const size_t stride_g = static_cast<size_t>(c_stride) * kNVecSMem * num_rows;  // Stride in global memory
    const size_t num_ele =
        c_g < num_rows ? min(static_cast<size_t>(kNVecOut), num_rows - c_g) : 0;  // For not aligned case
    OType* output_g = &output_t[r_g * num_rows + c_g];  // Output address in global memory
    // Each group of kNumThreadsStore threads within a warp processes one row; we need to find the
    // lane id of the first thread to do the reduction.
    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
    // This mask represents which threads should do the reduction together.
    const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
    const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;

#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      SMemVec smem_vec[kNVecOut];
      // Step 3.1: Load from shared memory to registers
#pragma unroll
      for (int i = 0; i < kNVecOut; ++i) {
        int r = r_s + i;
        int c = c_s;
        smem_vec[i] = smem[r * kSMemCol + c];
      }
#pragma unroll
      for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
        // Step 3.2: Compute local amax
        CType amax = 0;
#pragma unroll
        for (int i = 0; i < kNVecOut; ++i) {
          amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx]));
        }
        // Step 3.3: Reduce amax
#pragma unroll
        for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
          const float other_amax = __shfl_down_sync(mask, amax, delta);
          __builtin_assume(amax >= 0);
          __builtin_assume(other_amax >= 0);
          amax = fmaxf(amax, other_amax);
        }
        amax = __shfl_sync(mask, amax, src_lane);
        // Step 3.4: Compute scale
        CType scale;
        scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
        // Step 3.5: Write scale_inv_t
        bool write_scale_inv = is_src_lane;
        if constexpr (!kAligned) {
          write_scale_inv &= (r_g + smem_idx < row_length);
        }
        if (write_scale_inv) {
          CType scale_inv = 1.0 / scale;
          size_t row_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + smem_idx;
          size_t col_idx = static_cast<size_t>(blockIdx.y);
          tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
        }
        // Step 3.6: Quantize
        OVec output_vec;
#pragma unroll
        for (int i = 0; i < kNVecOut; ++i) {
          output_vec.data.elt[i] =
              static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale);
        }
        // Step 3.7: Store output_t
        if constexpr (kAligned) {
          output_vec.store_to(output_g + smem_idx * num_rows);
        } else {
          if (r_g + smem_idx < row_length) {
            output_vec.store_to_elts(output_g + smem_idx * num_rows, 0, num_ele);
          }
        }
      }
      // Step 3.8: Update output address, column index of shared memory (and row index of global
      // memory for not aligned case)
      output_g += stride_g;
      c_s += c_stride;
      if constexpr (!kAligned) {
        r_g += c_stride * kNVecSMem;
      }
    }
  }
}

}  // namespace
}  // namespace transformer_engine
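// Worked example of the sub-warp reduction set-up used in Steps 2.3 and 3.3 of the kernel above
// (values follow from the default constants; shown for orientation only):
//   kNumThreadsStore = kTileDim / kNVecOut = 128 / 16 = 8, so each group of 8 consecutive lanes
//   cooperates on one 1x128 (or 128x1) quantization block.
//   For a thread with threadIdx.x % kThreadsPerWarp == 19:
//     src_lane = 19 / 8 * 8           = 16
//     mask     = ((1 << 8) - 1) << 16 = 0x00FF0000
//   Lanes 16..23 shuffle their amax down in strides of 4, 2, 1; lane 16 ends up with the group
//   maximum, which is then broadcast back to the group with __shfl_sync(mask, amax, src_lane).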
namespace transformer_engine::detail {

void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
                                         SimpleTensor &scale_inv_t, SimpleTensor &output,
                                         SimpleTensor &output_t, const float epsilon,
                                         FP8BlockwiseRowwiseOption rowwise_option,
                                         FP8BlockwiseColumnwiseOption columnwise_option,
                                         const bool pow2_scale, cudaStream_t stream) {
  NVTE_API_CALL(quantize_transpose_vector_blockwise);

  // assert that rowwise_option and columnwise_option are not both NONE
  NVTE_CHECK(rowwise_option != FP8BlockwiseRowwiseOption::NONE ||
                 columnwise_option != FP8BlockwiseColumnwiseOption::NONE,
             "rowwise_option and columnwise_option cannot both be NONE");

  const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
  size_t num_elements = row_length;
  size_t num_rows = 1;
  for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
    num_rows *= input.shape.at(i);
    num_elements *= input.shape.at(i);
  }

  // Early return if the input tensor is empty
  if (num_elements == 0) {
    return;
  }

  // Options for scale layout of cuBLAS GEMM kernel.
  size_t scale_stride_x = 0;
  size_t scale_stride_y = 0;
  size_t scale_t_stride_x = 0;
  size_t scale_t_stride_y = 0;

  if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
    NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE, "Unexpected rowwise enum value");
    NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
    NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2.");
    size_t scale_k = scale_inv.shape[1];
    scale_stride_x = scale_k;
    scale_stride_y = 1;
  }

  if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) {
    NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE,
               "Unexpected columnwise enum value");
    NVTE_CHECK(output_t.shape.size() == input.shape.size(),
               "output_t must have same number of dimensions as input.");
    if (output_t.shape.size() > 0) {
      NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t.");
      for (size_t i = 1; i < output_t.shape.size(); ++i) {
        NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t");
      }
    }
    NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype.");
    NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2.");
    scale_t_stride_x = scale_inv_t.shape[1];
    scale_t_stride_y = 1;
  }

  const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
  const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.dtype, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
          output.dtype, OutputType,

          dim3 grid(num_blocks_x, num_blocks_y, 1);
          const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;

          TRANSFORMER_ENGINE_SWITCH_CONDITION(
              full_tile, kAligned,

              size_t smem_bytes = kSMemSize * sizeof(InputType);
              // Dynamic shared memory above the 48 KB default limit must be requested up front.
              if (smem_bytes >= 48 * 1024) {
                cudaError_t err = cudaFuncSetAttribute(
                    &block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
                    cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
                NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
              }
              block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
                  <<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
                      reinterpret_cast<const InputType *>(input.dptr),
                      reinterpret_cast<OutputType *>(output.dptr),
                      reinterpret_cast<OutputType *>(output_t.dptr),
                      reinterpret_cast<float *>(scale_inv.dptr),
                      reinterpret_cast<float *>(scale_inv_t.dptr), row_length, num_rows,
                      scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
                      rowwise_option, columnwise_option, pow2_scale);)  // kAligned
          )                                                             // OutputType
      )                                                                 // InputType

  NVTE_CHECK_CUDA(cudaGetLastError());
}

}  // namespace transformer_engine::detail
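For orientation, a hedged sketch of a call site follows. The sizes, tensor set-up, and helper name below are hypothetical and only illustrate how the options map onto the 1x128 rowwise and 128x1 columnwise-transpose outputs; for a 2048 x 3072 input the geometry works out to a 24 x 16 grid of 128x128 tiles, both extents are multiples of kTileDim, and the aligned kernel specialization is launched.

// Hypothetical call-site sketch (not part of the library). SimpleTensor construction is assumed
// to be done by the caller, with shapes satisfying the checks in quantize_transpose_vector_blockwise.
void quantize_activation_sketch(const transformer_engine::SimpleTensor &input,  // 2048 x 3072, bf16
                                transformer_engine::SimpleTensor &scale_inv,    // rowwise 1x128 scales
                                transformer_engine::SimpleTensor &scale_inv_t,  // columnwise 128x1 scales
                                transformer_engine::SimpleTensor &output,       // 2048 x 3072, fp8
                                transformer_engine::SimpleTensor &output_t,     // 3072 x 2048, fp8
                                cudaStream_t stream) {
  using namespace transformer_engine::detail;
  quantize_transpose_vector_blockwise(input, scale_inv, scale_inv_t, output, output_t,
                                      /*epsilon=*/0.0f, FP8BlockwiseRowwiseOption::ROWWISE,
                                      FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE,
                                      /*pow2_scale=*/true, stream);
}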