Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
063ef88d
Commit
063ef88d
authored
Dec 03, 2025
by
wenjh
Browse files
Merge nv main up to v2.10.0.dev0
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parents
91670b05
5624dbb4
Changes
298
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3480 additions
and
303 deletions
+3480
-303
transformer_engine/common/gemm/config.h
transformer_engine/common/gemm/config.h
+36
-0
transformer_engine/common/gemm/cublaslt_gemm.cu
transformer_engine/common/gemm/cublaslt_gemm.cu
+415
-57
transformer_engine/common/hadamard_transform/hadamard_transform.cu
...er_engine/common/hadamard_transform/hadamard_transform.cu
+876
-0
transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu
...mmon/hadamard_transform/hadamard_transform_cast_fusion.cu
+841
-0
transformer_engine/common/include/transformer_engine/activation.h
...mer_engine/common/include/transformer_engine/activation.h
+41
-0
transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h
...ine/common/include/transformer_engine/comm_gemm_overlap.h
+25
-0
transformer_engine/common/include/transformer_engine/fused_attn.h
...mer_engine/common/include/transformer_engine/fused_attn.h
+73
-37
transformer_engine/common/include/transformer_engine/gemm.h
transformer_engine/common/include/transformer_engine/gemm.h
+182
-7
transformer_engine/common/include/transformer_engine/hadamard_transform.h
...ne/common/include/transformer_engine/hadamard_transform.h
+68
-0
transformer_engine/common/include/transformer_engine/recipe.h
...sformer_engine/common/include/transformer_engine/recipe.h
+4
-0
transformer_engine/common/include/transformer_engine/swizzle.h
...former_engine/common/include/transformer_engine/swizzle.h
+20
-0
transformer_engine/common/include/transformer_engine/transformer_engine.h
...ne/common/include/transformer_engine/transformer_engine.h
+46
-4
transformer_engine/common/normalization/layernorm/ln_api.cpp
transformer_engine/common/normalization/layernorm/ln_api.cpp
+4
-4
transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
...ormer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
+4
-4
transformer_engine/common/recipe/__init__.py
transformer_engine/common/recipe/__init__.py
+174
-13
transformer_engine/common/recipe/current_scaling.cu
transformer_engine/common/recipe/current_scaling.cu
+19
-8
transformer_engine/common/recipe/nvfp4.cu
transformer_engine/common/recipe/nvfp4.cu
+54
-0
transformer_engine/common/swizzle/swizzle.cu
transformer_engine/common/swizzle/swizzle.cu
+194
-162
transformer_engine/common/swizzle/swizzle_block_scaling.cu
transformer_engine/common/swizzle/swizzle_block_scaling.cu
+321
-0
transformer_engine/common/transformer_engine.cpp
transformer_engine/common/transformer_engine.cpp
+83
-7
No files found.
transformer_engine/common/gemm/config.h
0 → 100644
View file @
063ef88d
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_GEMM_CONFIG_H_
#define TRANSFORMER_ENGINE_GEMM_CONFIG_H_

#include <transformer_engine/transformer_engine.h>

namespace transformer_engine {

/* Optional configuration for a matmul, consumed via the NVTEMatmulConfig
 * handle in nvte_cublas_gemm_v2. All fields default to "disabled". */
struct MatmulConfig {
  /* Bias tensor added in the GEMM epilogue (forward pass). */
  NVTETensor bias_tensor = nullptr;
  /* Bias-gradient tensor produced by the GEMM epilogue (backward pass). */
  NVTETensor dbias_tensor = nullptr;
  /* Fuse GELU into the GEMM epilogue. */
  bool with_gelu_epilogue = false;
  /* Fuse GELU backward into the GEMM epilogue. */
  bool with_dgelu_epilogue = false;
  /* Auxiliary tensor required by the GELU/dGELU epilogues. */
  NVTETensor epilogue_aux_tensor = nullptr;
  /* Use FP8 split accumulator (disables cuBLAS fast accumulation). */
  bool use_split_accumulator = false;
  /* SM count target passed to cuBLAS; presumably 0 means the cuBLAS
   * default -- confirm against nvte_cublas_gemm_v2 usage. */
  int sm_count = 0;

  /* Byte size of each attribute, indexed in declaration order. */
  static constexpr size_t attr_sizes[] = {
      sizeof(NVTETensor),  // bias_tensor
      sizeof(NVTETensor),  // dbias_tensor
      sizeof(bool),        // with_gelu_epilogue
      sizeof(bool),        // with_dgelu_epilogue
      sizeof(NVTETensor),  // epilogue_aux_tensor
      sizeof(bool),        // use_split_accumulator
      sizeof(int)          // sm_count
  };
};

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_GEMM_CONFIG_H_
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
063ef88d
...
...
@@ -15,23 +15,58 @@
#endif // #ifndef __HIP_PLATFORM_AMD__
#include <transformer_engine/gemm.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/transformer_engine.h>
#include <algorithm>
#include <cstdint>
#include <mutex>
#include <vector>
#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#include "../util/multi_stream.h"
#include "
common/util/cuda_runtime
.h"
#include "
./config
.h"
#ifndef __HIP_PLATFORM_AMD__
#include "cutlass_grouped_gemm.cuh"
#include "
./
cutlass_grouped_gemm.cuh"
#endif
#ifndef __HIP_PLATFORM_AMD__
namespace
{
/* Scalar constants 1.0f and 0.0f kept in CUDA __constant__ memory so that
 * cuBLAS can read alpha/beta scalars directly from device memory. Values
 * are written lazily by GetScalarOne()/GetScalarZero(). */
__device__ __constant__ float one_device;
__device__ __constant__ float zero_device;
/* Return a device pointer to a __constant__-memory float holding 1.0f.
 * The constant is written exactly once; std::call_once makes the lazy
 * initialization thread-safe. Intended for cuBLAS device-pointer-mode
 * alpha/beta scalars. */
inline float *GetScalarOne() {
  static std::once_flag init_flag;
  std::call_once(init_flag, []() {
    const float one = 1.0f;
    NVTE_CHECK_CUDA(cudaMemcpyToSymbol(one_device, &one, sizeof(float)));
  });
  // Resolve the device address of the __constant__ symbol.
  float *dev_ptr = nullptr;
  NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast<void **>(&dev_ptr), one_device));
  return dev_ptr;
}
/* Return a device pointer to a __constant__-memory float holding 0.0f.
 * Mirrors GetScalarOne(): one-time thread-safe initialization, then the
 * symbol address is looked up on every call. */
inline float *GetScalarZero() {
  static std::once_flag init_flag;
  std::call_once(init_flag, []() {
    const float zero = 0.0f;
    NVTE_CHECK_CUDA(cudaMemcpyToSymbol(zero_device, &zero, sizeof(float)));
  });
  // Resolve the device address of the __constant__ symbol.
  float *dev_ptr = nullptr;
  NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast<void **>(&dev_ptr), zero_device));
  return dev_ptr;
}
/* Single-thread kernel that stores `val` into `*ptr` on the device.
 * Used to stage a host-side scalar (e.g. a non-0/1 beta) into workspace
 * memory for cuBLAS device pointer mode. */
__global__ __launch_bounds__(1) void set_float_kernel(float *ptr, float val) {
  *ptr = val;
}
uint32_t
_getAlignment
(
uintptr_t
address
)
{
// alignment are in bytes
uint32_t
alignment
=
256
;
...
...
@@ -91,6 +126,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
bool
is_A_transposed
=
transA
==
CUBLAS_OP_T
;
bool
is_B_transposed
=
transB
==
CUBLAS_OP_T
;
// Set conditions for MXFP8 and NVFP4 gemm execution.
const
auto
nvfp4
=
is_nvfp_scaling
(
A
.
scaling_mode
)
&&
is_nvfp_scaling
(
B
.
scaling_mode
);
const
auto
mxfp8
=
!
nvfp4
&&
is_mxfp_scaling
(
A
.
scaling_mode
)
&&
is_mxfp_scaling
(
B
.
scaling_mode
);
// Configure A matrix
if
(
is_tensor_scaling
(
A
.
scaling_mode
))
{
// Unscaled or FP8 tensor scaling
...
...
@@ -111,10 +150,32 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
NVTE_CHECK
(
!
is_fp8_dtype
(
ret
.
Atype
),
"Input A is missing column-wise usage"
);
}
}
}
else
if
(
is_mxfp_scaling
(
A
.
scaling_mode
))
{
// MXFP8
if
(
is_fp8_dtype
(
ret
.
Atype
))
{
// Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
NVTE_CHECK
(
ret
.
lda
%
16
==
0
,
"Leading dimension requirement on A for FP8 GEMM. Caller must pad."
);
}
}
else
if
(
nvfp4
)
{
// NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe.
if
(
is_A_transposed
)
{
NVTE_CHECK
(
A
.
has_data
(),
"Input A is missing row-wise usage"
);
}
else
{
NVTE_CHECK
(
is_nvfp4_scaling
(
A
.
scaling_mode
),
"Input A has unsupported combination of recipe and layout"
);
NVTE_CHECK
(
A
.
has_columnwise_data
(),
"Input A is missing column-wise usage"
);
}
ret
.
A
=
is_A_transposed
?
A
.
data
.
dptr
:
A
.
columnwise_data
.
dptr
;
ret
.
transA
=
CUBLAS_OP_T
;
// NVFP4 gemm is only supported in TN layout.
ret
.
Atype
=
is_A_transposed
?
A
.
data
.
dtype
:
A
.
columnwise_data
.
dtype
;
ret
.
A_scale_inv
=
is_A_transposed
?
A
.
scale_inv
.
dptr
:
A
.
columnwise_scale_inv
.
dptr
;
ret
.
lda
=
k
;
}
else
if
(
mxfp8
)
{
// MXFP8 GEMM. Either for pure MXFP8 recipe or backward of Hybrid NVFP4 recipe.
// Note: Row-wise and column-wise data are scaled along different
// dimensions (with matrix interpreted in row-major order).
if
(
is_A_transposed
)
{
NVTE_CHECK
(
A
.
has_data
(),
"Input A is missing row-wise usage"
);
}
else
{
...
...
@@ -141,7 +202,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
// Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
NVTE_CHECK
((
ret
.
lda
%
16
)
==
0
,
"
Inner
dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."
);
"
Leading
dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."
);
// Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement.
// Smallest supported CType is 2 bytes in this scaling mode.
NVTE_CHECK
((
m
%
8
)
==
0
,
...
...
@@ -170,10 +231,26 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
NVTE_CHECK
(
!
is_fp8_dtype
(
ret
.
Btype
),
"Input B is missing column-wise usage"
);
}
}
}
else
if
(
is_mxfp_scaling
(
B
.
scaling_mode
))
{
// MXFP8
// Note: Row-wise and column-wise data are scaled along different
// dimensions (with matrix interpreted in row-major order).
if
(
is_fp8_dtype
(
ret
.
Atype
))
{
// Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
NVTE_CHECK
(
ret
.
ldb
%
16
==
0
,
"Leading dimension requirement on B for FP8 GEMM. Caller must pad."
);
}
}
else
if
(
nvfp4
)
{
if
(
is_B_transposed
)
{
NVTE_CHECK
(
is_nvfp4_scaling
(
B
.
scaling_mode
),
"Input B has unsupported combination of recipe and layout"
);
NVTE_CHECK
(
B
.
has_columnwise_data
(),
"Input B is missing column-wise usage"
);
}
else
{
NVTE_CHECK
(
B
.
has_data
(),
"Input B is missing row-wise usage"
);
}
ret
.
B
=
is_B_transposed
?
B
.
columnwise_data
.
dptr
:
B
.
data
.
dptr
;
ret
.
transB
=
CUBLAS_OP_N
;
// NVFP4 gemm is only supported in TN layout.
ret
.
Btype
=
is_B_transposed
?
B
.
columnwise_data
.
dtype
:
B
.
data
.
dtype
;
ret
.
B_scale_inv
=
is_B_transposed
?
B
.
columnwise_scale_inv
.
dptr
:
B
.
scale_inv
.
dptr
;
ret
.
ldb
=
k
;
}
else
if
(
mxfp8
)
{
if
(
is_B_transposed
)
{
NVTE_CHECK
(
B
.
has_columnwise_data
(),
"Input B is missing column-wise usage"
);
}
else
{
...
...
@@ -238,7 +315,7 @@ using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublas
void
cublas_gemm
(
const
Tensor
*
inputA
,
const
Tensor
*
inputB
,
Tensor
*
outputD
,
const
Tensor
*
inputBias
,
Tensor
*
outputPreGelu
,
cublasOperation_t
transa
,
cublasOperation_t
transb
,
bool
grad
,
void
*
workspace
,
size_t
workspaceSize
,
float
alpha
,
float
beta
,
bool
use_split_accumulator
,
int
math_sm_count
,
const
void
*
alpha
,
const
void
*
beta
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
m_split
,
int
n_split
,
bool
gemm_producer
,
const
Tensor
*
inputCounter
,
cudaStream_t
stream
)
{
// Tensor dims in row-major order
...
...
@@ -277,6 +354,49 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
}
const
bool
gelu
=
pre_gelu_out
!=
nullptr
;
const
bool
use_fp8
=
is_fp8_dtype
(
param
.
Atype
)
||
is_fp8_dtype
(
param
.
Btype
);
const
bool
use_fp4
=
is_fp4_dtype
(
param
.
Atype
)
||
is_fp4_dtype
(
param
.
Btype
);
// Update scaling factors with NVFP4 tensor scales
// TODO: Check whether scales are on CPU/GPU or add API to control.
// Currently scales are assumed to be on CPU when amax is provided
// and on GPU when not provided, but this is brittle.
if
(
use_fp4
&&
(
inputA
->
amax
.
dptr
!=
nullptr
||
inputB
->
amax
.
dptr
!=
nullptr
))
{
// Reserve some workspace for alpha scale
NVTE_CHECK
(
workspaceSize
>=
4
,
"NVFP4 GEMM requires at least 4 byte workspace for alpha scale, but only has "
,
workspaceSize
,
" bytes remaining."
);
workspaceSize
=
(
workspaceSize
/
4
)
*
4
-
4
;
// Remove last 4 aligned bytes
uint8_t
*
workspace_ptr
=
reinterpret_cast
<
uint8_t
*>
(
workspace
);
float
*
new_alpha_ptr
=
reinterpret_cast
<
float
*>
(
&
workspace_ptr
[
workspaceSize
]);
// Update alpha scale on device
// Note: Compute NVFP4 tensor scales based on amaxes and then
// divide from alpha scale. This way we only need to apply NVFP4
// tensor scales in matmul output, instead of in matmul inputs.
float
old_alpha
=
*
reinterpret_cast
<
const
float
*>
(
alpha
);
// Assumed to be on CPU
TensorWrapper
new_alpha_tensor
(
new_alpha_ptr
,
std
::
vector
<
size_t
>
{
1
},
DType
::
kFloat32
);
nvte_nvfp4_compute_per_tensor_scale
(
inputA
->
nvte_tensor
,
transa
,
inputB
->
nvte_tensor
,
!
transb
,
old_alpha
,
new_alpha_tensor
.
data
(),
stream
);
alpha
=
new_alpha_ptr
;
// Make sure beta scale is on device
float
old_beta
=
*
reinterpret_cast
<
const
float
*>
(
beta
);
// Assumed to be on CPU
if
(
old_beta
==
0
)
{
beta
=
GetScalarZero
();
// Device constant memory
}
else
if
(
old_beta
==
1
)
{
beta
=
GetScalarOne
();
// Device constant memory
}
else
{
// Move beta to workspace
NVTE_CHECK
(
workspaceSize
>=
4
,
"NVFP4 GEMM requires at least 4 byte workspace for beta scale, but only has "
,
workspaceSize
,
" bytes remaining."
);
workspaceSize
=
(
workspaceSize
/
4
)
*
4
-
4
;
// Remove last 4 aligned bytes
float
*
new_beta_ptr
=
reinterpret_cast
<
float
*>
(
&
workspace_ptr
[
workspaceSize
]);
set_float_kernel
<<<
1
,
1
,
0
,
stream
>>>
(
new_beta_ptr
,
old_beta
);
NVTE_CHECK_CUDA
(
cudaGetLastError
());
beta
=
new_beta_ptr
;
}
}
const
cudaDataType_t
A_type
=
get_cuda_dtype
(
param
.
Atype
);
const
cudaDataType_t
B_type
=
get_cuda_dtype
(
param
.
Btype
);
...
...
@@ -287,16 +407,23 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
"FP8 input to GEMM requires inverse of scale!"
);
NVTE_CHECK
(
!
is_fp8_dtype
(
param
.
Btype
)
||
param
.
B_scale_inv
!=
nullptr
,
"FP8 input to GEMM requires inverse of scale!"
);
NVTE_CHECK
(
!
is_fp4_dtype
(
param
.
Atype
)
||
param
.
A_scale_inv
!=
nullptr
,
"FP4 input to GEMM requires inverse of scale!"
);
NVTE_CHECK
(
!
is_fp4_dtype
(
param
.
Btype
)
||
param
.
B_scale_inv
!=
nullptr
,
"FP4 input to GEMM requires inverse of scale!"
);
// check consistency of arguments:
// if fp8 is desired, context cannot be null
// fp8 + gelu fusion + fp8 aux is unavailable right now.
if
(
use_fp8
&&
gelu
)
{
if
(
(
use_fp8
||
use_fp4
)
&&
gelu
)
{
NVTE_CHECK
(
!
is_fp8_dtype
(
outputPreGelu
->
data
.
dtype
),
"fp8 Aux output for gemm + gelu fusion not supported!"
);
}
if
(
is_fp8_dtype
(
outputD
->
data
.
dtype
))
{
NVTE_CHECK
(
beta
==
0.0
f
,
"Accumulation mode not supported with FP8 GEMM output!"
);
if
(
is_fp4_dtype
(
outputD
->
data
.
dtype
))
{
NVTE_ERROR
(
"FP4 GEMM output is not supported!"
);
}
if
(
use_fp4
&&
(
D_type
==
CUDA_R_16F
))
{
NVTE_ERROR
(
"FP4 GEMM does not support FP16 output!"
);
}
cublasLtHandle_t
handle
=
cublasHandleManager
::
Instance
().
GetHandle
();
...
...
@@ -336,12 +463,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
&
math_sm_count
,
sizeof
(
math_sm_count
)));
}
// set fp8 attributes -- input and output types should already be set to fp8
as appropriate
// Note: gelu fusion isn't available right now, and we don't need
// set fp8
/fp4
attributes -- input and output types should already be set to fp8
/fp4
//
as appropriate.
Note: gelu fusion isn't available right now, and we don't need
// amax(D) either (next op is high precision).
if
(
use_fp8
)
{
// Split accumulator.
const
int8_t
fastAccuMode
=
(
use_split_accumulator
)
?
0
:
1
;
const
bool
mxfp8_gemm
=
!
use_fp4
&&
is_mxfp8_scaling
(
inputA
->
scaling_mode
);
if
(
use_fp8
||
use_fp4
)
{
// Fast accumulation is only supported for FP8.
const
int8_t
fastAccuMode
=
(
use_split_accumulator
)
?
0
:
use_fp8
;
NVTE_CHECK_CUBLAS
(
cublasLtMatmulDescSetAttribute
(
operationDesc
,
CUBLASLT_MATMUL_DESC_FAST_ACCUM
,
&
fastAccuMode
,
sizeof
(
fastAccuMode
)));
...
...
@@ -350,7 +479,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
cublasLtMatmulMatrixScale_t
scaling_mode_a
;
cublasLtMatmulMatrixScale_t
scaling_mode_b
;
#endif // CUBLAS_VERSION >= 120800
if
(
(
is_tensor_scaling
(
inputA
->
scaling_mode
)
&&
is_tensor_scaling
(
inputB
->
scaling_mode
))
)
{
if
(
is_tensor_scaling
(
inputA
->
scaling_mode
)
&&
is_tensor_scaling
(
inputB
->
scaling_mode
))
{
void
*
A_scale_inverse
=
param
.
A_scale_inv
;
void
*
B_scale_inverse
=
param
.
B_scale_inv
;
NVTE_CHECK_CUBLAS
(
cublasLtMatmulDescSetAttribute
(
operationDesc
,
...
...
@@ -363,7 +492,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
scaling_mode_a
=
CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F
;
scaling_mode_b
=
CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F
;
#endif // CUBLAS_VERSION >= 120800
}
else
if
(
(
is_mxfp_scaling
(
inputA
->
scaling_mode
)
&&
is_mxfp_scaling
(
inputB
->
scaling_mode
))
)
{
}
else
if
(
mxfp8_gemm
)
{
#if CUBLAS_VERSION >= 120800
NVTE_CHECK
(
cublas_version
()
>=
120800
,
"MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is "
,
cublas_version
());
...
...
@@ -388,6 +517,34 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#else
NVTE_ERROR
(
"MXFP8 requires cuBLAS 12.8+, but compile-time cuBLAS version is "
,
CUBLAS_VERSION
);
#endif // CUBLAS_VERSION >= 120800
}
else
if
(
use_fp4
)
{
// NVFP4 GEMM
#if CUBLAS_VERSION >= 120800
NVTE_CHECK
(
cublas_version
()
>=
120800
,
"FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is "
,
cublas_version
());
// make sure alpha beta computation dtype remains fp32 by CUBLASLT_MATMUL_DESC_SCALE_TYPE
cublasDataType_t
scale_type
=
CUDA_R_32F
;
NVTE_CHECK_CUBLAS
(
cublasLtMatmulDescSetAttribute
(
operationDesc
,
CUBLASLT_MATMUL_DESC_SCALE_TYPE
,
&
scale_type
,
sizeof
(
scale_type
)));
// Set pointer mode: alpha and beta are both device pointers
// https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t
cublasLtPointerMode_t
pointer_mode
=
CUBLASLT_POINTER_MODE_DEVICE
;
NVTE_CHECK_CUBLAS
(
cublasLtMatmulDescSetAttribute
(
operationDesc
,
CUBLASLT_MATMUL_DESC_POINTER_MODE
,
&
pointer_mode
,
sizeof
(
pointer_mode
)));
fp8e4m3
*
A_scale_inverse
=
reinterpret_cast
<
fp8e4m3
*>
(
param
.
A_scale_inv
);
fp8e4m3
*
B_scale_inverse
=
reinterpret_cast
<
fp8e4m3
*>
(
param
.
B_scale_inv
);
NVTE_CHECK_CUBLAS
(
cublasLtMatmulDescSetAttribute
(
operationDesc
,
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER
,
&
A_scale_inverse
,
sizeof
(
A_scale_inverse
)));
NVTE_CHECK_CUBLAS
(
cublasLtMatmulDescSetAttribute
(
operationDesc
,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER
,
&
B_scale_inverse
,
sizeof
(
B_scale_inverse
)));
scaling_mode_a
=
CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3
;
scaling_mode_b
=
CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3
;
#else
NVTE_ERROR
(
"FP4 requires cuBLAS 12.8+, but compile-time cuBLAS version is "
,
CUBLAS_VERSION
);
#endif // CUBLAS_VERSION >= 120800
}
else
if
((
inputA
->
scaling_mode
==
NVTE_BLOCK_SCALING_1D
||
inputA
->
scaling_mode
==
NVTE_BLOCK_SCALING_2D
)
&&
...
...
@@ -520,14 +677,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
NVTE_ERROR
(
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is "
,
CUDA_VERSION
);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR
(
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is "
,
CUBLAS_VERSION
);
#endif
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \
CUBLAS_VERSION < 130000
#else
NVTE_CHECK
(
cuda
::
cudart_version
()
>=
12020
&&
cuda
::
cudart_version
()
<
13000
,
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is "
,
cuda
::
cudart_version
());
...
...
@@ -554,6 +708,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#endif
}
// align the workspace to 256 B
const
int
required_alignment
=
256
;
const
auto
original_workspace_alignment
=
_getAlignment
(
reinterpret_cast
<
uintptr_t
>
(
workspace
));
uint8_t
*
aligned_workspace_ptr
=
reinterpret_cast
<
uint8_t
*>
(
workspace
)
+
required_alignment
-
original_workspace_alignment
;
workspaceSize
=
workspaceSize
-
required_alignment
+
original_workspace_alignment
;
const
auto
new_workspace_alignment
=
_getAlignment
(
reinterpret_cast
<
uintptr_t
>
(
aligned_workspace_ptr
));
NVTE_CHECK_CUBLAS
(
cublasLtMatmulPreferenceCreate
(
&
preference
));
NVTE_CHECK_CUBLAS
(
cublasLtMatmulPreferenceSetAttribute
(
preference
,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES
,
&
workspaceSize
,
sizeof
(
workspaceSize
)));
...
...
@@ -561,7 +723,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const
auto
B_alignment
=
_getAlignment
(
reinterpret_cast
<
uintptr_t
>
(
param
.
B
));
const
auto
C_alignment
=
_getAlignment
(
reinterpret_cast
<
uintptr_t
>
(
C
));
const
auto
D_alignment
=
_getAlignment
(
reinterpret_cast
<
uintptr_t
>
(
D
));
const
auto
workspace_alignment
=
_getAlignment
(
reinterpret_cast
<
uintptr_t
>
(
workspace
));
NVTE_CHECK_CUBLAS
(
cublasLtMatmulPreferenceSetAttribute
(
preference
,
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES
,
&
A_alignment
,
sizeof
(
A_alignment
)));
NVTE_CHECK_CUBLAS
(
cublasLtMatmulPreferenceSetAttribute
(
...
...
@@ -570,8 +731,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
preference
,
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES
,
&
C_alignment
,
sizeof
(
C_alignment
)));
NVTE_CHECK_CUBLAS
(
cublasLtMatmulPreferenceSetAttribute
(
preference
,
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES
,
&
D_alignment
,
sizeof
(
D_alignment
)));
NVTE_CHECK
(
workspace_alignment
%
256
==
0
,
"cuBLAS workspace pointer must be aligned to 256 bytes, got "
,
workspace_alignment
);
NVTE_CHECK
(
new_workspace_alignment
%
256
==
0
,
"cuBLAS workspace pointer must be aligned to 256 bytes, got "
,
new_workspace_alignment
);
const
auto
status
=
cublasLtMatmulAlgoGetHeuristic
(
handle
,
operationDesc
,
Adesc
,
Bdesc
,
Cdesc
,
Ddesc
,
preference
,
...
...
@@ -582,16 +744,15 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
if
(
returnedResults
==
0
)
NVTE_ERROR
(
"Unable to find any suitable algorithms"
);
// D = alpha * (A * B) + beta * C
NVTE_CHECK_CUBLAS
(
cublasLtMatmul
(
handle
,
operationDesc
,
static_cast
<
const
void
*>
(
&
alpha
),
/* alpha */
param
.
A
,
/* A */
Adesc
,
param
.
B
,
/* B */
Bdesc
,
static_cast
<
const
void
*>
(
&
beta
),
/* beta */
C
,
/* C */
Cdesc
,
D
,
/* D */
Ddesc
,
&
heuristicResult
.
algo
,
/* algo */
workspace
,
/* workspace */
workspaceSize
,
stream
));
/* stream */
NVTE_CHECK_CUBLAS
(
cublasLtMatmul
(
handle
,
operationDesc
,
alpha
,
/* alpha */
param
.
A
,
/* A */
Adesc
,
param
.
B
,
/* B */
Bdesc
,
beta
,
/* beta */
C
,
/* C */
Cdesc
,
D
,
/* D */
Ddesc
,
&
heuristicResult
.
algo
,
/* algo */
aligned_workspace_ptr
,
/* workspace */
workspaceSize
,
stream
));
/* stream */
// Update FP8 scale-inv in output tensor
// Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated.
...
...
@@ -666,13 +827,26 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
int
math_sm_count
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
,
bool
nvte_use_rocblas
,
int
compute_stream_offset
)
{
NVTE_API_CALL
(
nvte_cublas_gemm
);
using
namespace
transformer_engine
;
// Tensors
const
Tensor
*
inputA
=
convertNVTETensorCheck
(
A
);
const
Tensor
*
inputB
=
convertNVTETensorCheck
(
B
);
Tensor
*
outputD
=
convertNVTETensor
(
D
);
Tensor
*
outputD
=
convertNVTETensor
Check
(
D
);
const
Tensor
*
biasTensor
=
convertNVTETensor
(
bias
);
Tensor
*
outputGelu
=
convertNVTETensor
(
pre_gelu_out
);
Tensor
*
wspace
=
convertNVTETensor
(
workspace
);
// Scales
const
float
alpha
=
1
;
const
float
beta
=
accumulate
?
1
:
0
;
// Check for NVFP4
// TODO Remove once alpha scale logic is moved into cublas_gemm function
if
(
is_nvfp_scaling
(
inputA
->
scaling_mode
)
||
is_nvfp_scaling
(
inputB
->
scaling_mode
))
{
NVTE_ERROR
(
"nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead."
);
}
#ifdef __HIP_PLATFORM_AMD__
const
size_t
A0
=
inputA
->
flat_first_dim
();
const
size_t
A1
=
inputA
->
flat_last_dim
();
...
...
@@ -734,9 +908,135 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
#else
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
(
transa
)
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
(
transb
)
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
1.0
f
,
(
accumulate
)
?
1.0
f
:
0.0
f
,
use_split_accumulator
,
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
);
#endif //__HIP_PLATFORM_AMD__
&
alpha
,
&
beta
,
use_split_accumulator
,
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
);
#endif
}
void
nvte_cublas_gemm_v2
(
int
transa
,
int
transb
,
const
float
*
alpha
,
const
NVTETensor
A
,
const
NVTETensor
B
,
const
float
*
beta
,
const
NVTETensor
C
,
NVTETensor
D
,
NVTETensor
workspace
,
NVTEMatmulConfig
config
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
,
bool
nvte_use_rocblas
,
int
compute_stream_offset
)
{
NVTE_API_CALL
(
nvte_cublas_gemm_v2
);
using
namespace
transformer_engine
;
// Data tensors
const
Tensor
*
A_tensor
=
convertNVTETensorCheck
(
A
);
const
Tensor
*
B_tensor
=
convertNVTETensorCheck
(
B
);
const
Tensor
*
C_tensor
=
convertNVTETensorCheck
(
C
);
Tensor
*
D_tensor
=
convertNVTETensorCheck
(
D
);
NVTE_CHECK
(
C_tensor
==
D_tensor
,
"Currently nvte_cublas_gemm_v2 does not support different C and D tensors."
);
// Workspace
void
*
workspace_ptr
=
nullptr
;
size_t
workspace_size
=
0
;
Tensor
*
workspace_tensor
=
convertNVTETensor
(
workspace
);
if
(
workspace_tensor
!=
nullptr
)
{
workspace_ptr
=
workspace_tensor
->
data
.
dptr
;
workspace_size
=
get_buffer_size_bytes
(
workspace_tensor
->
data
.
numel
(),
workspace_tensor
->
data
.
dtype
);
}
// Additional config
MatmulConfig
config_
;
if
(
config
!=
nullptr
)
{
config_
=
*
reinterpret_cast
<
MatmulConfig
*>
(
config
);
}
// Configure GEMM epilogue
const
bool
with_grad_epilogue
=
(
config_
.
dbias_tensor
!=
nullptr
||
config_
.
with_dgelu_epilogue
);
if
(
with_grad_epilogue
)
{
NVTE_CHECK
(
config_
.
bias_tensor
==
nullptr
&&
!
config_
.
with_gelu_epilogue
,
"Invalid epilogue (bias="
,
config_
.
bias_tensor
!=
nullptr
,
", dbias="
,
config_
.
dbias_tensor
!=
nullptr
,
", gelu="
,
config_
.
with_gelu_epilogue
,
", dgelu="
,
config_
.
with_dgelu_epilogue
,
")."
);
}
Tensor
dummy_tensor
;
Tensor
*
epilogue_bias_tensor
=
&
dummy_tensor
;
if
(
!
with_grad_epilogue
&&
config_
.
bias_tensor
!=
nullptr
)
{
epilogue_bias_tensor
=
convertNVTETensorCheck
(
config_
.
bias_tensor
);
}
else
if
(
with_grad_epilogue
&&
config_
.
dbias_tensor
!=
nullptr
)
{
epilogue_bias_tensor
=
convertNVTETensorCheck
(
config_
.
dbias_tensor
);
}
Tensor
*
epilogue_aux_tensor
=
&
dummy_tensor
;
if
(
config_
.
with_gelu_epilogue
||
config_
.
with_dgelu_epilogue
)
{
NVTE_CHECK
(
config_
.
epilogue_aux_tensor
!=
nullptr
,
"Requested epilogue (bias="
,
config_
.
bias_tensor
!=
nullptr
,
", dbias="
,
config_
.
dbias_tensor
!=
nullptr
,
", gelu="
,
config_
.
with_gelu_epilogue
,
", dgelu="
,
config_
.
with_dgelu_epilogue
,
") without providing aux tensor."
);
epilogue_aux_tensor
=
convertNVTETensor
(
config_
.
epilogue_aux_tensor
);
}
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK
(
*
alpha
==
1.0
f
,
"alpha must be 1.0 for hip"
);
NVTE_CHECK
(
*
beta
==
1.0
f
||
*
beta
==
0.0
f
,
"beta must be 1.0 or 0.0 for hip"
);
bool
accumulate
=
false
;
if
(
*
alpha
==
1.0
f
and
*
beta
==
1.0
f
)
{
accumulate
=
true
;
}
const
size_t
A0
=
A_tensor
->
flat_first_dim
();
const
size_t
A1
=
A_tensor
->
flat_last_dim
();
const
size_t
B0
=
B_tensor
->
flat_first_dim
();
const
size_t
B1
=
B_tensor
->
flat_last_dim
();
const
int
m
=
transa
?
A0
:
A1
;
const
int
k
=
transa
?
A1
:
A0
;
const
int
n
=
transb
?
B1
:
B0
;
int
lda
,
ldb
,
ldd
;
if
(
transa
&&
!
transb
)
{
// TN
lda
=
k
;
ldb
=
k
;
ldd
=
m
;
}
else
if
(
!
transa
&&
!
transb
)
{
// NN
lda
=
m
;
ldb
=
k
;
ldd
=
m
;
}
else
if
(
!
transa
&&
transb
)
{
// NT
lda
=
m
;
ldb
=
n
;
ldd
=
m
;
}
else
{
// TT
NVTE_ERROR
(
"TT layout not allowed."
);
}
const
bool
use_int8
=
is_int8_dtype
(
A_tensor
->
data
.
dtype
)
||
is_int8_dtype
(
B_tensor
->
data
.
dtype
);
const
char
*
NVTE_FORCE_ROCM_GEMM
=
std
::
getenv
(
"NVTE_FORCE_ROCM_GEMM"
);
const
bool
use_fp8
=
is_fp8_dtype
(
A_tensor
->
data
.
dtype
)
||
is_fp8_dtype
(
B_tensor
->
data
.
dtype
);
const
char
*
NVTE_INT8_SIM_FP8_TENSORWISE
=
std
::
getenv
(
"NVTE_INT8_SIM_FP8_TENSORWISE"
);
if
(
NVTE_INT8_SIM_FP8_TENSORWISE
!=
nullptr
&&
NVTE_INT8_SIM_FP8_TENSORWISE
[
0
]
==
'1'
&&
use_int8
&&
config_
.
use_split_accumulator
)
nvte_use_hipblaslt
=
1
;
if
((
epilogue_bias_tensor
->
data
.
dptr
!=
nullptr
)
||
(
epilogue_aux_tensor
->
data
.
dptr
!=
nullptr
)
||
(
use_fp8
)
||
(
NVTE_FORCE_ROCM_GEMM
!=
nullptr
&&
NVTE_FORCE_ROCM_GEMM
[
0
]
==
'1'
)
||
(
nvte_use_hipblaslt
)
||
(
nvte_use_rocblas
))
{
cublas_gemm
(
A_tensor
,
B_tensor
,
D_tensor
,
epilogue_bias_tensor
,
epilogue_aux_tensor
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
transa
,
transb
,
with_grad_epilogue
,
workspace_ptr
,
workspace_size
,
accumulate
,
config_
.
use_split_accumulator
,
config_
.
sm_count
,
0
,
0
,
false
,
nullptr
,
stream
,
nvte_use_hipblaslt
,
nvte_use_rocblas
,
compute_stream_offset
);
}
else
{
hipblas_gemm
(
A_tensor
,
B_tensor
,
D_tensor
,
epilogue_bias_tensor
,
epilogue_aux_tensor
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
(
transa
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
(
transb
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
with_grad_epilogue
,
workspace_ptr
,
workspace_size
,
accumulate
,
config_
.
use_split_accumulator
,
config_
.
sm_count
,
0
,
0
,
false
,
nullptr
,
stream
);
}
#else
// Launch GEMM
cublas_gemm
(
A_tensor
,
B_tensor
,
D_tensor
,
epilogue_bias_tensor
,
epilogue_aux_tensor
,
transa
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
transb
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
with_grad_epilogue
,
workspace_ptr
,
workspace_size
,
alpha
,
beta
,
config_
.
use_split_accumulator
,
config_
.
sm_count
,
0
,
0
,
false
,
nullptr
,
stream
);
#endif
}
void
nvte_cublas_gemm_scaled
(
const
NVTETensor
A
,
const
NVTETensor
B
,
NVTETensor
D
,
...
...
@@ -745,13 +1045,21 @@ void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor
bool
use_split_accumulator
,
int
math_sm_count
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
,
bool
nvte_use_rocblas
,
int
compute_stream_offset
)
{
NVTE_API_CALL
(
nvte_cublas_gemm_scaled
);
using
namespace
transformer_engine
;
// Tensors
const
Tensor
*
inputA
=
convertNVTETensorCheck
(
A
);
const
Tensor
*
inputB
=
convertNVTETensorCheck
(
B
);
Tensor
*
outputD
=
convertNVTETensor
(
D
);
Tensor
*
outputD
=
convertNVTETensor
Check
(
D
);
const
Tensor
*
biasTensor
=
convertNVTETensor
(
bias
);
Tensor
*
outputGelu
=
convertNVTETensor
(
pre_gelu_out
);
Tensor
*
wspace
=
convertNVTETensor
(
workspace
);
// Check for NVFP4
// TODO Remove once alpha scale logic is moved into cublas_gemm function
if
(
is_nvfp_scaling
(
inputA
->
scaling_mode
)
||
is_nvfp_scaling
(
inputB
->
scaling_mode
))
{
NVTE_ERROR
(
"nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead."
);
}
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK
(
alpha
==
1.0
f
,
"alpha must be 1.0 for hip"
);
NVTE_CHECK
(
beta
==
1.0
f
||
beta
==
0.0
f
,
"beta must be 1.0 or 0.0 for hip"
);
...
...
@@ -820,7 +1128,7 @@ void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor
#else
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
(
transa
)
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
(
transb
)
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
alpha
,
beta
,
use_split_accumulator
,
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
);
&
alpha
,
&
beta
,
use_split_accumulator
,
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
);
#endif
}
...
...
@@ -838,12 +1146,12 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
NVTE_ERROR
(
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is "
,
CUDA_VERSION
);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR
(
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is "
,
CUBLAS_VERSION
);
#endif
#else
#define NVTE_CUBLAS_ATOMIC_GEMM_COMPILE 1
NVTE_CHECK
(
transformer_engine
::
cuda
::
cudart_version
()
>=
12020
&&
transformer_engine
::
cuda
::
cudart_version
()
<
13000
,
...
...
@@ -854,7 +1162,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is "
,
cublas_version
());
#endif
#else
#define NVTE_CUBLAS_ATOMIC_GEMM_COMPILE 1
#endif // __HIP_PLATFORM_AMD__
#ifdef NVTE_CUBLAS_ATOMIC_GEMM_COMPILE
const
Tensor
*
inputA
=
convertNVTETensorCheck
(
A
);
const
Tensor
*
inputB
=
convertNVTETensorCheck
(
B
);
Tensor
*
outputD
=
convertNVTETensor
(
D
);
...
...
@@ -863,6 +1175,9 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
const
Tensor
*
inputCounter
=
convertNVTETensor
(
counter
);
Tensor
*
wspace
=
convertNVTETensor
(
workspace
);
const
void
*
alpha_ptr
=
GetScalarOne
();
const
void
*
beta_ptr
=
accumulate
?
GetScalarOne
()
:
GetScalarZero
();
NVTE_CHECK
(
is_delayed_tensor_scaling
(
inputA
->
scaling_mode
)
&&
is_delayed_tensor_scaling
(
inputB
->
scaling_mode
),
"Atomic GEMM only supports delayed scaling."
);
...
...
@@ -917,9 +1232,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#else
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
(
transa
)
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
(
transb
)
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
1.0
f
,
(
accumulate
)
?
1.0
f
:
0.0
f
,
use_split_accumulator
,
math_sm_count
,
m_split
,
n_split
,
gemm_producer
,
inputCounter
,
stream
);
alpha_ptr
,
beta_ptr
,
use_split_accumulator
,
math_sm_count
,
m_split
,
n_split
,
gemm_producer
,
inputCounter
,
stream
);
#endif //__HIP_PLATFORM_AMD__
#endif // NVTE_CUBLAS_ATOMIC_GEMM_COMPILE
}
void
multi_stream_cublas_gemm
(
const
NVTETensor
*
A
,
const
NVTETensor
*
B
,
NVTETensor
*
D
,
...
...
@@ -948,17 +1264,59 @@ void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETens
}
else
{
NVTE_FORCE_BLAS_MULSTREAM
=
false
;
}
if
(
NVTE_FORCE_BLAS_MULSTREAM
){
if
(
NVTE_FORCE_BLAS_MULSTREAM
)
{
for
(
int
i
=
0
;
i
<
num_gemms
;
i
++
)
{
nvte_cublas_gemm
(
A
[
i
],
B
[
i
],
D
[
i
],
bias
[
i
],
pre_gelu_out
[
i
],
transa
,
transb
,
grad
,
workspace
[
i
%
num_streams
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
detail
::
get_compute_stream
(
i
%
num_streams
));
// Check whether GELU or dGELU epilogue is requested
Tensor
*
pre_gelu_tensor
=
convertNVTETensor
(
pre_gelu_out
[
i
]);
bool
with_gelu_dgelu_epilogue
=
(
pre_gelu_tensor
!=
nullptr
&&
pre_gelu_tensor
->
data
.
dptr
!=
nullptr
);
// Construct config
MatmulConfig
config
;
if
(
grad
)
{
config
.
dbias_tensor
=
bias
[
i
];
config
.
with_dgelu_epilogue
=
with_gelu_dgelu_epilogue
;
}
else
{
config
.
bias_tensor
=
bias
[
i
];
config
.
with_gelu_epilogue
=
with_gelu_dgelu_epilogue
;
}
config
.
epilogue_aux_tensor
=
pre_gelu_out
[
i
];
config
.
use_split_accumulator
=
use_split_accumulator
;
config
.
sm_count
=
math_sm_count
;
// Launch GEMM
const
float
alpha
=
1.
f
;
const
float
beta
=
accumulate
?
1.
f
:
0.
f
;
nvte_cublas_gemm_v2
(
transa
,
transb
,
&
alpha
,
A
[
i
],
B
[
i
],
&
beta
,
D
[
i
],
D
[
i
],
workspace
[
i
%
num_streams
],
&
config
,
detail
::
get_compute_stream
(
i
%
num_streams
));
}
}
else
{
}
else
{
for
(
int
i
=
0
;
i
<
num_gemms
;
i
++
)
{
nvte_cublas_gemm
(
A
[
i
],
B
[
i
],
D
[
i
],
bias
[
i
],
pre_gelu_out
[
i
],
transa
,
transb
,
grad
,
workspace
[
i
%
num_streams
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
detail
::
get_compute_stream
(
i
%
num_streams
),
1
,
0
,
i
%
num_streams
);
// Check whether GELU or dGELU epilogue is requested
Tensor
*
pre_gelu_tensor
=
convertNVTETensor
(
pre_gelu_out
[
i
]);
bool
with_gelu_dgelu_epilogue
=
(
pre_gelu_tensor
!=
nullptr
&&
pre_gelu_tensor
->
data
.
dptr
!=
nullptr
);
// Construct config
MatmulConfig
config
;
if
(
grad
)
{
config
.
dbias_tensor
=
bias
[
i
];
config
.
with_dgelu_epilogue
=
with_gelu_dgelu_epilogue
;
}
else
{
config
.
bias_tensor
=
bias
[
i
];
config
.
with_gelu_epilogue
=
with_gelu_dgelu_epilogue
;
}
config
.
epilogue_aux_tensor
=
pre_gelu_out
[
i
];
config
.
use_split_accumulator
=
use_split_accumulator
;
config
.
sm_count
=
math_sm_count
;
// Launch GEMM
const
float
alpha
=
1.
f
;
const
float
beta
=
accumulate
?
1.
f
:
0.
f
;
nvte_cublas_gemm_v2
(
transa
,
transb
,
&
alpha
,
A
[
i
],
B
[
i
],
&
beta
,
D
[
i
],
D
[
i
],
workspace
[
i
%
num_streams
],
&
config
,
detail
::
get_compute_stream
(
i
%
num_streams
),
1
,
0
,
i
%
num_streams
);
}
}
...
...
transformer_engine/common/hadamard_transform/hadamard_transform.cu
0 → 100644
View file @
063ef88d
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <transformer_engine/hadamard_transform.h>
#include <cuda/barrier>
#include "common/common.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
namespace
transformer_engine
{
namespace
{
// Number of threads in a CUDA warp.
constexpr int kThreadsPerWarp = 32;
// Magnitude of every entry of the normalized 16x16 Hadamard matrix:
// 1/sqrt(16) = 0.25, so that H16 * H16^T = I (orthonormal transform).
constexpr float k16x16HadamardScale = 0.25f;
// Loads four 8x8 b16 matrix tiles from shared memory into the four 32-bit
// registers a0..a3 of each lane, using the warp-collective `ldmatrix.x4`
// instruction. `addr` is a generic pointer into shared memory (converted to a
// shared-state-space address via cvta). When kTranspose is true the `.trans`
// variant is used, so each 8x8 tile is transposed on load.
template <bool kTranspose>
__device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t &a0, uint32_t &a1,
                                                            uint32_t &a2, uint32_t &a3,
                                                            void *addr) {
  // ldmatrix requires a shared-memory address, not a generic one.
  auto smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(addr));
  if constexpr (kTranspose) {
    asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
                 : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
                 : "r"(smem_addr));
  } else {
    asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
                 : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
                 : "r"(smem_addr));
  }
}
// Loads a 16x16 bf16 tile from shared memory into a WMMA "A" fragment
// (four 32-bit registers per lane) via `wmma.load.a`. `stride` is the
// leading-dimension stride in elements. kTranspose selects the column-major
// (`.col`) load, i.e. the tile is read transposed.
template <bool kTranspose>
__device__ __forceinline__ void load_matrix_16x16_from_shared(uint32_t &a0, uint32_t &a1,
                                                              uint32_t &a2, uint32_t &a3,
                                                              void *addr, uint32_t stride) {
  if constexpr (kTranspose) {
    asm volatile(
        "wmma.load.a.sync.aligned.col.m16n16k16.shared::cta.bf16 "
        "{%0,%1,%2,%3}, [%4], %5;\n"
        : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
        : "l"(addr), "r"(stride));
  } else {
    asm volatile(
        "wmma.load.a.sync.aligned.row.m16n16k16.shared::cta.bf16 "
        "{%0,%1,%2,%3}, [%4], %5;\n"
        : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
        : "l"(addr), "r"(stride));
  }
}
// Stores a 16x16 result fragment (four 32-bit registers per lane, each holding
// two packed 16-bit values) to global memory via `wmma.store.d`. `stride` is
// the output leading dimension in elements; kTranspose selects the
// column-major (`.col`) store. NOTE(review): the instruction uses the `.f16`
// qualifier although the payload here is bf16 — the store moves raw 16-bit
// lanes, so this appears bit-preserving; confirm against the PTX ISA.
template <bool kTranspose>
__device__ __forceinline__ void store_matrix_16x16_to_global(uint32_t &a0, uint32_t &a1,
                                                             uint32_t &a2, uint32_t &a3,
                                                             void *addr, uint32_t stride) {
  if constexpr (kTranspose) {
    asm volatile(
        "wmma.store.d.sync.aligned.col.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
        :
        : "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
  } else {
    asm volatile(
        "wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
        :
        : "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
  }
}
// Warp-collective in-place transpose of an 8x8 b16 tile held across the warp
// (one 32-bit register per lane), using the `movmatrix` instruction.
__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(uint32_t &a0) {
  asm volatile(
      "movmatrix.sync.aligned.m8n8.trans.b16 "
      "%0, %1;\n\t"
      : "=r"(a0)
      : "r"(a0));
}
// Reduces a packed bf16x2 register to a single float:
// float_dst = max(|lo|, |hi|). The `max.xorsign.abs` instruction computes the
// max of the absolute values but re-applies the XOR of the input signs, so a
// final fabsf() strips the sign to yield the absolute maximum.
__device__ __forceinline__ void unpack_max_of_packed_bf16(uint32_t &packed_bf16,
                                                          float &float_dst) {
  // Reinterpret the 32-bit register as two bf16 lanes and widen to float.
  __nv_bfloat162 bf16x2 = *reinterpret_cast<__nv_bfloat162 *>(&packed_bf16);
  float f_a = __bfloat162float(bf16x2.x);
  float f_b = __bfloat162float(bf16x2.y);
  asm volatile("max.xorsign.abs.f32 %0, %1, %2;\n\t"
               : "=f"(float_dst)
               : "f"(f_a), "f"(f_b));
  // Drop the xor-combined sign; we only want the magnitude (amax).
  float_dst = fabsf(float_dst);
}
// Computes a 16x16x16 matrix product C = A @ B with bf16 inputs and no
// accumulator (the f32 accumulators are fed zeros), then packs the eight f32
// results of each lane back into four bf16x2 registers c0..c3.
// When kCalculateAmax is set, also reduces amax(|C|) of this lane's fragment
// into the packed-bf16x2 register `amax_result` (overwriting, not
// accumulating — see N.B. below).
template <bool kCalculateAmax>
__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc(
    uint32_t &a0, uint32_t &a1, uint32_t &a2, uint32_t &a3, uint32_t &b0, uint32_t &b1,
    uint32_t &b2, uint32_t &b3, uint32_t &c0, uint32_t &c1, uint32_t &c2, uint32_t &c3,
    uint32_t &amax_result) {
  uint32_t zero = 0;
  // temp0..temp7 receive the eight f32 accumulator values of this lane.
  uint32_t temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7;
  asm volatile(
      "wmma.mma.sync.aligned.row.row.m16n16k16.f32.bf16.bf16.f32\n"
      "{%0, %1, %2, %3, %4, %5, %6, %7},\n"
      "{%8, %9, %10, %11},\n"
      "{%12, %13, %14, %15},\n"
      "{%16, %17, %18, %19, %20, %21, %22, %23};\n\t"
      : "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3), "=r"(temp4), "=r"(temp5),
        "=r"(temp6), "=r"(temp7)
      : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), "r"(b3), "r"(zero),
        "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero));
  // Convert pairs of f32 results to packed bf16x2 (round-to-nearest).
  asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0));
  asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2));
  asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c2) : "r"(temp5), "r"(temp4));
  asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c3) : "r"(temp7), "r"(temp6));
  if constexpr (kCalculateAmax) {
    uint32_t max_even;
    uint32_t max_odd;
    // Reduction tree to amax(abs(result)) into bf16x2 reg outparam.
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(max_even)
                 : "r"(c0), "r"(c2));
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(max_odd)
                 : "r"(c1), "r"(c3));
    // N.B. mma is only called up to once per thread for identity and transpose respectively, so
    // we don't have to accumulate into amax_result and can directly store into it.
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(amax_result)
                 : "r"(max_even), "r"(max_odd));
  }
}
// Builds this lane's WMMA fragment(s) of the scaled 16x16 randomized Hadamard
// matrix, packed as bf16x2 into had_frag_i (identity pass) and/or had_frag_t
// (transpose pass). The base sign of entry (r, c) is the parity of
// __popc(r & c) — the standard Walsh-Hadamard construction — and every entry
// has magnitude k16x16HadamardScale. The random sign masks flip whole
// columns (forward) or rows (inverse) of the fragment.
template <bool kReturnIdentity, bool kReturnTransposed, bool kInverseHadamardIdentity,
          bool kInverseHadamardTransposed>
__device__ __forceinline__ void get_hadamard_matrix_fragment(uint32_t *had_frag_i,
                                                             uint16_t random_sign_mask,
                                                             uint32_t *had_frag_t,
                                                             uint16_t random_sign_mask_t) {
  int32_t tid = threadIdx.x % 32;  // Local tid
  float temp_i[2];
  float temp_t[2];
#pragma unroll
  for (int i = 0; i < 2; i++) {
    // i is the vertical fragment index.
    // For a 16x16 matrix matrix fragment, 4 threads fill a fragment of 8 BF16 vals.
    uint32_t r = i * 8 + tid / 4;
#pragma unroll
    for (int j = 0; j < 2; j++) {
#pragma unroll
      for (int k = 0; k < 2; k++) {
        // k is column position [0, 1] within a quad of 2 BF16s stored together in 32 bits.
        // j is the column fragment idx selecting between even and odd fragments.
        // j increments 8 columns by switching fragments.
        uint32_t c = j * 8 + k + tid % 4 * 2;
        // 1 -> -1.0f, 0 -> 1.0f
        int32_t base_sign = __popc(r & c);
        if constexpr (kReturnIdentity) {
          int32_t sign_i;
          // Because tensor cores want the dot product dimension,
          // contiguous, the regular, non-inverse hadamard swaps
          // signs of columns and rows for inverse. In a simple reference,
          // x.reshape(-1, 16) @ sign @ H16, this would be opposite but
          // (sign @ H16) is transposed in this fragment.
          if constexpr (kInverseHadamardIdentity) {
            sign_i = ((random_sign_mask >> r) ^ base_sign);
          } else {
            sign_i = ((random_sign_mask >> c) ^ base_sign);
          }
          // Only the LSB of sign_i matters: shift it into the float sign bit.
          temp_i[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_i << 31));
        }
        if constexpr (kReturnTransposed) {
          int32_t sign_t;
          if constexpr (kInverseHadamardTransposed) {
            sign_t = ((random_sign_mask_t >> r) ^ base_sign);
          } else {
            sign_t = ((random_sign_mask_t >> c) ^ base_sign);
          }
          temp_t[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_t << 31));
        }
      }
      // Pack the two column values into one bf16x2 register of the fragment.
      if constexpr (kReturnIdentity) {
        asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
                     : "=r"(had_frag_i[i * 2 + j])
                     : "f"(temp_i[1]), "f"(temp_i[0]));
      }
      if constexpr (kReturnTransposed) {
        asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
                     : "=r"(had_frag_t[i * 2 + j])
                     : "f"(temp_t[1]), "f"(temp_t[0]));
      }
    }
  }
}
// Maps a (row, col) tile coordinate onto its swizzled shared-memory slot
// index (in units of 16-byte elements, 8 per row). Within each row, the
// column is XOR-ed with a row-dependent factor so consecutive rows land in
// different shared-memory banks. NOTE(review): the name suggests this matches
// the TMA "128B swizzle with 32B atom" layout — confirm against the tensor
// map configuration used for the copies.
__device__ __forceinline__ uint32_t swizzle_128B_atom_32B(uint32_t gmem_row_idx,
                                                          uint32_t gmem_col_idx) {
  const uint32_t xor_factor = (gmem_row_idx * 2) % 8;
  const uint32_t swizzled_col = gmem_col_idx ^ xor_factor;
  return gmem_row_idx * 8 + swizzled_col;
}
// Per-warp compute step of the Hadamard amax kernel. Loads one 16x16 bf16
// tile from swizzled shared memory, multiplies it by the (randomized)
// Hadamard fragments, and folds the results into three packed-bf16x2 running
// amax registers:
//   - local_amax_reg:        amax of the row-wise (identity) transform
//   - local_amax_t_reg:      amax of the transposed transform
//   - local_pre_rht_amax_reg: amax of the raw input tile (pre-RHT)
// Improvement over the original: the template-bool guards use `if constexpr`
// (consistent with HadamardTransformKernel below) so disabled paths are
// discarded at compile time, and the temporary amax registers are scoped to
// the branches that use them. Behavior is unchanged.
template <typename IType, int kHadamardDimension, int BUFF_DIM_Y, int BUFF_DIM_X,
          bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
                                              IType *in_sh_ptr,
                                              uint32_t &local_pre_rht_amax_reg,
                                              uint32_t &local_amax_reg,
                                              uint32_t &local_amax_t_reg) {
  uint32_t a_frag[4];  // A matrix fragment
  uint32_t c_frag[4];  // Result fragment
  int warp_id = threadIdx.x / kThreadsPerWarp;
  int local_rank = (threadIdx.x % kThreadsPerWarp);
  // Each lane loads one 16-byte slot; the swizzle matches the layout the
  // tiles were written with.
  int ld_row_idx = local_rank % kHadamardDimension;
  int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
  int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);
  if constexpr (kReturnIdentityAmax) {
    uint32_t temp_amax_reg;
    ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                       reinterpret_cast<uint4 *>(in_sh_ptr) + swizzle_idx);
    mma_m16_n16_k16_b16_b16_b16_noacc<kReturnIdentityAmax>(
        a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
        b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
    // Fold this tile's amax into the running packed-bf16x2 maximum.
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(local_amax_reg)
                 : "r"(local_amax_reg), "r"(temp_amax_reg));
  }
  if constexpr (kReturnTransposedAmax) {
    // TODO(Frank): This is not efficient, since we could directly load the
    // matrix in transposed layout.
    if constexpr (!kReturnIdentityAmax) {
      ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                         reinterpret_cast<uint4 *>(in_sh_ptr) + swizzle_idx);
    }
    uint32_t temp_amax_t_reg;
    // Transpose each 8x8 sub-tile in place, then swap the middle fragment
    // indices in the mma call to complete the 16x16 transpose.
    matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
    matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
    matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
    matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
    mma_m16_n16_k16_b16_b16_b16_noacc<kReturnTransposedAmax>(
        a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2],
        b_frag_t[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_t_reg);
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(local_amax_t_reg)
                 : "r"(local_amax_t_reg), "r"(temp_amax_t_reg));
  }
  if constexpr (kReturnPreRhtAmax) {
    if constexpr (!kReturnIdentityAmax && !kReturnTransposedAmax) {
      ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                         reinterpret_cast<uint4 *>(in_sh_ptr) + swizzle_idx);
    }
    // Reduce the four input fragments (transposition above does not change
    // the set of values, hence the amax) into the running pre-RHT maximum.
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(a_frag[0])
                 : "r"(a_frag[0]), "r"(a_frag[1]));
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(a_frag[2])
                 : "r"(a_frag[2]), "r"(a_frag[3]));
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(a_frag[0])
                 : "r"(a_frag[0]), "r"(a_frag[2]));
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(local_pre_rht_amax_reg)
                 : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg));
  }
}
// Returns the smallest power of 2 that is >= kN.
// Fix: the original evaluated `__builtin_clz(kN - 1)` unconditionally, which
// is undefined behavior for kN == 1 (clz of 0) — and `static_assert(kN > 0)`
// permits kN == 1 (e.g. a single-warp block). Handle that case explicitly.
template <int kN>
__device__ __host__ constexpr int NextPowerOf2() {
  static_assert(kN > 0, "kN must be > 0");
  if (kN == 1) {
    return 1;  // __builtin_clz(0) is undefined; 1 is already a power of 2.
  }
  // Round up to the next power of 2 by counting leading zeros.
  return 1 << (32 - __builtin_clz(kN - 1));
}
// Block-wide max reduction of up to three per-thread amax values into global
// outputs. Stage 1: each warp reduces internally and lane 0 writes its value
// to the per-warp staging buffers. Stage 2 (after __syncthreads): warps 0, 1
// and 2 each reduce one staging buffer and atomically fold the block result
// into the corresponding global output.
// NOTE(review): the final stage assumes the block has at least 3 warps when
// all three amaxes are requested (warpid 1 handles transpose, warpid 2
// handles pre-RHT) — confirm against the launch configurations used.
template <int kNumWarps, bool kReturnPreRhtAmax, bool kReturnIdentityAmax,
          bool kReturnTransposedAmax>
__device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float identity_amax,
                                          const float transpose_amax,
                                          float *staging_for_pre_rht,
                                          float *staging_for_identity,
                                          float *staging_for_transpose,
                                          float *output_pre_rht_amax_ptr,
                                          float *output_identity_amax_ptr,
                                          float *output_transpose_amax_ptr, const int warpid) {
  // intra-warp reduction
  constexpr int kWarpSize = 32;
  int local_rank = threadIdx.x % 32;
  float warp_pre_rht_amax =
      kReturnPreRhtAmax ? warp_reduce_max<kWarpSize>(pre_rht_amax) : 0.0f;
  float warp_identity_amax =
      kReturnIdentityAmax ? warp_reduce_max<kWarpSize>(identity_amax) : 0.0f;
  float warp_transpose_amax =
      kReturnTransposedAmax ? warp_reduce_max<kWarpSize>(transpose_amax) : 0.0f;
  // inter-warp reduction
  if (threadIdx.x % 32 == 0) {
    // Lane 0 of each warp publishes the warp's maximum to shared staging.
    if (kReturnPreRhtAmax) {
      staging_for_pre_rht[warpid] = warp_pre_rht_amax;
    }
    if (kReturnIdentityAmax) {
      staging_for_identity[warpid] = warp_identity_amax;
    }
    if (kReturnTransposedAmax) {
      staging_for_transpose[warpid] = warp_transpose_amax;
    }
  }
  __syncthreads();
  // Pad the shuffle-reduction width up to a power of 2 (slots beyond
  // kNumWarps read 0.0f, which is the identity for amax of |values|).
  constexpr int kNumWarpsPow2 = NextPowerOf2<kNumWarps>();
  if (warpid == 0) {
    if (kReturnIdentityAmax) {
      float identity_accum = local_rank < kNumWarps ? staging_for_identity[local_rank] : 0.0f;
      identity_accum = warp_reduce_max<kNumWarpsPow2>(identity_accum);
      if (local_rank == 0) {
        atomicMaxFloat(output_identity_amax_ptr, identity_accum);
      }
    }
  }
  if (warpid == 1) {
    if (kReturnTransposedAmax) {
      float transpose_accum =
          local_rank < kNumWarps ? staging_for_transpose[local_rank] : 0.0f;
      transpose_accum = warp_reduce_max<kNumWarpsPow2>(transpose_accum);
      if (local_rank == 0) {
        atomicMaxFloat(output_transpose_amax_ptr, transpose_accum);
      }
    }
  }
  if (warpid == 2) {
    if (kReturnPreRhtAmax) {
      float pre_rht_accum = local_rank < kNumWarps ? staging_for_pre_rht[local_rank] : 0.0f;
      pre_rht_accum = warp_reduce_max<kNumWarpsPow2>(pre_rht_accum);
      if (local_rank == 0) {
        atomicMaxFloat(output_pre_rht_amax_ptr, pre_rht_accum);
      }
    }
  }
}
// Single-thread kernel that zero-initializes whichever of the three amax
// output scalars are present (null pointers are skipped). Launched before the
// amax kernels so their atomicMax reductions start from 0.
__launch_bounds__(1) __global__
    void ZeroAmaxKernel(float *__restrict__ output_pre_rht_amax_ptr,
                        float *__restrict__ output_identity_amax_ptr,
                        float *__restrict__ output_transpose_amax_ptr) {
  float *const amax_ptrs[] = {output_pre_rht_amax_ptr, output_identity_amax_ptr,
                              output_transpose_amax_ptr};
  for (float *const ptr : amax_ptrs) {
    if (ptr != nullptr) {
      *ptr = 0;
    }
  }
}
// TMA-based kernel computing amaxes of a bf16 input under the randomized
// Hadamard transform. Each block owns a CHUNK_DIM_Y x CHUNK_DIM_X region of
// the input, streams it through two BUFF_DIM_Y x BUFF_DIM_X shared-memory
// buffers (ping-pong, prefetching the next stage while computing the current
// one), and accumulates up to three amaxes (pre-RHT, identity RHT,
// transposed RHT) that are reduced block-wide and folded into the global
// outputs with atomicMax. Requires SM 10.0+ (TMA bulk tensor copies).
// NOTE(review): num_rows / row_length are not referenced in the body —
// presumably kept for interface symmetry with the launcher; confirm.
template <typename IType, int kHadamardDimension, int CHUNK_DIM_Y, int CHUNK_DIM_X,
          int BUFF_DIM_Y, int BUFF_DIM_X, int THREADS_PER_CHUNK, int THREADS_PER_Y,
          bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor_map_input,
                                      float *__restrict__ output_pre_rht_amax_ptr,
                                      float *__restrict__ output_identity_amax_ptr,
                                      float *__restrict__ output_transpose_amax_ptr,
                                      uint16_t random_sign_mask, uint16_t random_sign_mask_t,
                                      uint64_t num_rows, uint64_t row_length) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y && CHUNK_DIM_Y % BUFF_DIM_Y == 0);
  static_assert(CHUNK_DIM_X >= BUFF_DIM_X && CHUNK_DIM_X % BUFF_DIM_X == 0);
  constexpr size_t STAGES_Y = CHUNK_DIM_Y / BUFF_DIM_Y;
  constexpr size_t STAGES_X = CHUNK_DIM_X / BUFF_DIM_X;
  constexpr int kNumWarps = (THREADS_PER_CHUNK * THREADS_PER_Y) / kThreadsPerWarp;
  const int input_block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
  const int input_block_offset_X = blockIdx.x * CHUNK_DIM_X;
  extern __shared__ __align__(128) char dynamic_shmem[];
  uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) Does not guarantee the pointer to be aligned!
  uint8_t *dshmem = reinterpret_cast<uint8_t *>((base_shmem_ptr + 127) & ~127ULL);
  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
  constexpr size_t in_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType);
  // Carve the aligned shared memory into two input buffers (ping-pong), the
  // mbarrier array, and three per-warp amax staging buffers.
  IType *in_sh_0 = reinterpret_cast<IType *>(dshmem);
  dshmem += in_buff_size;
  IType *in_sh_1 = reinterpret_cast<IType *>(dshmem);
  dshmem += in_buff_size;
  IType *in_shs[2] = {in_sh_0, in_sh_1};
  constexpr int shmem_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType);
  const bool is_master_thread = (threadIdx.x == 0 && threadIdx.y == 0);
  // Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  uint64_t *mbar = reinterpret_cast<uint64_t *>(dshmem);
  dshmem += sizeof(uint64_t) * (STAGES_X * STAGES_Y);
  float *max_staging_identity = reinterpret_cast<float *>(dshmem);
  dshmem += sizeof(float) * kNumWarps;
  float *max_staging_transpose = reinterpret_cast<float *>(dshmem);
  dshmem += sizeof(float) * kNumWarps;
  float *max_staging_pre_rht = reinterpret_cast<float *>(dshmem);
  dshmem += sizeof(float) * kNumWarps;
  initialize_barriers<STAGES_X * STAGES_Y, THREADS_PER_CHUNK * THREADS_PER_Y>(
      mbar, is_master_thread);
  // Kick off the TMA copy for the first stage before computing anything.
  copy_2d_to_shared(in_shs[0], reinterpret_cast<const void *>(&tensor_map_input),
                    input_block_offset_X, input_block_offset_Y, shmem_buff_size, &mbar[0],
                    is_master_thread);
  // Build the Hadamard fragments once; they are reused for every tile.
  uint32_t had_frag_i[4];
  uint32_t had_frag_t[4];
  get_hadamard_matrix_fragment<kReturnIdentityAmax, kReturnTransposedAmax, false, false>(
      had_frag_i, random_sign_mask, had_frag_t, random_sign_mask_t);
  // Running amaxes live as packed bf16x2 registers; the float aliases are
  // only materialized after the loop via unpack_max_of_packed_bf16.
  float local_pre_rht_amax = 0.0;
  float local_amax = 0.0;
  float local_amax_t = 0.0;
  uint32_t local_pre_rht_amax_reg = *reinterpret_cast<uint32_t *>(&local_pre_rht_amax);
  uint32_t local_amax_reg = *reinterpret_cast<uint32_t *>(&local_amax);
  uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t *>(&local_amax_t);
  for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
    for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
      int stage = STAGES_X * stage_y + stage_x;
      const int next_stage = stage + 1;
      // Row-major traversal of stages: wrap x and bump y at the row end.
      const int next_stage_x = stage_x + 1 == STAGES_X ? 0 : stage_x + 1;
      const int next_stage_y = stage_x + 1 == STAGES_X ? stage_y + 1 : stage_y;
      if (next_stage < STAGES_X * STAGES_Y) {
        // Prefetch the next stage into the other (ping-pong) buffer.
        const int input_global_offset_Y = input_block_offset_Y + next_stage_y * BUFF_DIM_Y;
        const int input_global_offset_X = input_block_offset_X + next_stage_x * BUFF_DIM_X;
        copy_2d_to_shared(in_shs[next_stage % 2],  // ping-pong
                          reinterpret_cast<const void *>(&tensor_map_input),
                          input_global_offset_X, input_global_offset_Y, shmem_buff_size,
                          &mbar[next_stage], is_master_thread);
      }
      ptx::fence_proxy_async_shared_cta();
      // Wait for the data to have arrived
      ptx::mbarrier_wait_parity(&mbar[stage], 0);
      // Tile the buffer across warps: each warp handles kHadamardDimension
      // columns per x-step, each threadIdx.y row-group per y-step.
      const size_t compute_stage_x_num =
          BUFF_DIM_X / (kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp));
      const size_t compute_stage_y_num = BUFF_DIM_Y / (kHadamardDimension * THREADS_PER_Y);
      const size_t in_row_stride = BUFF_DIM_X;
      IType *in_sh_ptr = in_shs[stage % 2];
#pragma unroll
      for (size_t compute_stage_y = 0; compute_stage_y < compute_stage_y_num;
           compute_stage_y++) {
        const int row_idx_offset = (compute_stage_y * kHadamardDimension * THREADS_PER_Y +
                                    threadIdx.y * kHadamardDimension);
        const int in_row_offset = row_idx_offset * in_row_stride;
#pragma unroll
        for (size_t compute_stage_x = 0; compute_stage_x < compute_stage_x_num;
             compute_stage_x++) {
          ComputeKernel<IType, kHadamardDimension, BUFF_DIM_Y, BUFF_DIM_X, kReturnPreRhtAmax,
                        kReturnIdentityAmax, kReturnTransposedAmax>(
              had_frag_i, had_frag_t,
              in_sh_ptr + in_row_offset +
                  (compute_stage_x * kHadamardDimension *
                   (THREADS_PER_CHUNK / kThreadsPerWarp)),
              local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
        }
        // Ensure all threads have finished their computation before new data over-writes the shared
        // memory.
        __syncthreads();
      }
    }
  }
  const int warpid = (threadIdx.x + threadIdx.y * blockDim.x) / kThreadsPerWarp;
  // Unpack the packed-bf16x2 running maxima into floats for the reduction.
  if constexpr (kReturnPreRhtAmax) {
    unpack_max_of_packed_bf16(local_pre_rht_amax_reg, local_pre_rht_amax);
  }
  if constexpr (kReturnIdentityAmax) {
    unpack_max_of_packed_bf16(local_amax_reg, local_amax);
  }
  if constexpr (kReturnTransposedAmax) {
    unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t);
  }
  ReduceMax<kNumWarps, kReturnPreRhtAmax, kReturnIdentityAmax, kReturnTransposedAmax>(
      local_pre_rht_amax, local_amax, local_amax_t, max_staging_pre_rht, max_staging_identity,
      max_staging_transpose, output_pre_rht_amax_ptr, output_identity_amax_ptr,
      output_transpose_amax_ptr, warpid);
  destroy_barriers<STAGES_X * STAGES_Y>(mbar, is_master_thread);
#else
  NVTE_DEVICE_ERROR("Kernel is only supported on SM 10.0+.");
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
// Applies the (randomized) 16x16 Hadamard transform to a bf16 matrix, one
// 16x16 tile per warp. Optionally produces the row-wise result (`output`),
// the transposed-transform result (`output_t`, either at the same offsets or
// truly transposed when kOutputTrueTransposed), and/or running amaxes of
// either result via atomicMax. Data is staged through shared memory with
// cp.async so the Hadamard fragment construction overlaps the load.
// Requires SM 9.0+.
template <typename T, int kHadamardDimension, bool kComputeIdentity, bool kComputeTransposed,
          bool kReturnIdentity, bool kReturnTransposed, bool kUpdateIdentityAmax,
          bool kUpdateTransposeAmax, bool kOutputTrueTransposed>
__global__ void HadamardTransformKernel(const T *__restrict__ input, T *__restrict__ output,
                                        T *__restrict__ output_t, uint16_t random_sign_mask,
                                        uint16_t random_sign_mask_t, uint64_t num_input_rows,
                                        uint64_t num_input_cols, float *__restrict__ amax,
                                        float *__restrict__ amax_t, bool inverse_hadamard) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  static_assert(kHadamardDimension == 16, "Currently only hadamard dimension 16 is supported.");
  // The whole threadblock will share the same smem.
  extern __shared__ __align__(16) T smem[];
  // Each 32 threads process a 16x16 matrix. There is a (y, z) grid of 16x16.
  // If y = 4, z = 4, then each threadblock is processing a 4x4 grid of 16x16 matrices.
  int32_t tid = threadIdx.x;
  int32_t warp_id = threadIdx.y * blockDim.z + threadIdx.z;
  int32_t local_bx = threadIdx.y;
  int32_t local_by = threadIdx.z;
  // Define the register fragments
  uint32_t a_frag[4];    // A matrix fragment
  uint32_t b_frag_i[4];  // Transposed Hadamard matrix fragment, used for A @ B(col major)
  uint32_t b_frag_t[4];  // Hadamard matrix fragment, used for A.T @ B.T(col major)
  uint32_t c_frag[4];    // Result fragment
  // row and col for each thread. 32 threads will work together in 128 chunk to
  // load the data from global memory to shared memory.
  uint32_t row = tid / (kHadamardDimension * sizeof(T) / sizeof(uint4));
  uint32_t col = tid % (kHadamardDimension * sizeof(T) / sizeof(uint4));
  uint32_t smem_index = tid;
  uint32_t input_start_col = (blockIdx.x * blockDim.y + local_bx) * kHadamardDimension;
  uint32_t input_start_row = (blockIdx.y * blockDim.z + local_by) * kHadamardDimension;
  bool load = (input_start_col < num_input_cols) && (input_start_row < num_input_rows);
  if (!load) {
    // Out of bound, we are returning early. No thread divergence since the whole warp
    // will return early.
    return;
  }
  uint64_t global_offset = input_start_col + input_start_row * num_input_cols;
  // When emitting a true transpose, the tile lands at swapped coordinates.
  uint64_t global_offset_t = kOutputTrueTransposed
                                 ? (input_start_row + input_start_col * num_input_rows)
                                 : global_offset;
  T *base_smem = smem + kHadamardDimension * kHadamardDimension * warp_id;
  uint32_t *smem_b32 = reinterpret_cast<uint32_t *>(base_smem);
  uint4 *smem_b128 = reinterpret_cast<uint4 *>(base_smem);
  // Asynchronously load the data from global memory to shared memory.
  const uint4 *input_b128 = reinterpret_cast<const uint4 *>(input + global_offset);
  // Each 16x16 chunk is divided into 4 8x8 matrices, we are trying to load each
  // 8x8 chunks consecutively into the smem, so we could leverage ldmatrix m8n8x4
  // to load the data in the tensor core swizzled format.
  __pipeline_memcpy_async(&smem_b128[smem_index],
                          &input_b128[row * num_input_cols / (sizeof(uint4) / sizeof(T)) + col],
                          sizeof(uint4));
  __pipeline_commit();  // Commit the memcpy. Wait when we are in the computation.
  // Build the Hadamard fragments while the cp.async is in flight.
  if (inverse_hadamard) {
    get_hadamard_matrix_fragment<kComputeIdentity, kComputeTransposed,
                                 /*kInverseHadamard=*/true,
                                 /*kInverseHadamardTransposed=*/true>(
        b_frag_i, random_sign_mask, b_frag_t, random_sign_mask_t);
  } else {
    get_hadamard_matrix_fragment<kComputeIdentity, kComputeTransposed,
                                 /*kInverseHadamard=*/false,
                                 /*kInverseHadamardTransposed=*/false>(
        b_frag_i, random_sign_mask, b_frag_t, random_sign_mask_t);
  }
  float local_amax = 0.0;
  float local_amax_t = 0.0;
  uint32_t local_amax_reg = *reinterpret_cast<uint32_t *>(&local_amax);
  uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t *>(&local_amax_t);
  __pipeline_wait_prior(0);
  __syncwarp();  // ensure all lanes finished their cp.async before reading smem
  // Load the A to a_frag.
  if constexpr (kComputeIdentity) {
    load_matrix_16x16_from_shared<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3], smem_b32,
                                         kHadamardDimension);
    // 16x16 @ 16x16 leveraging all threads in the warp.
    mma_m16_n16_k16_b16_b16_b16_noacc<kUpdateIdentityAmax>(
        a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
        b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], local_amax_reg);
    // Store the result to the shared memory in non-transposed order.
    if constexpr (kReturnIdentity) {
      uint4 *output_b128 = reinterpret_cast<uint4 *>(output + global_offset);
      store_matrix_16x16_to_global<false>(c_frag[0], c_frag[1], c_frag[2], c_frag[3],
                                          output_b128, num_input_cols);
    }
  }
  if constexpr (kComputeTransposed) {
    if (kComputeIdentity) {
      // Reuse the already-loaded fragment: transpose each 8x8 sub-tile in
      // place (the index swap in the mma below completes the transpose).
      matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
      matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
      matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
      matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
    } else {
      load_matrix_16x16_from_shared<true>(a_frag[0],
                                          a_frag[2],  // NOTE: intentional index swapping
                                          a_frag[1],  // NOTE: intentional index swapping
                                          a_frag[3], smem_b32, kHadamardDimension);
    }
    mma_m16_n16_k16_b16_b16_b16_noacc<kUpdateTransposeAmax>(
        a_frag[0],
        // 2,1 is used if we are using movmatrix instruction.
        // Thus loading the matrix in 2,1 order will just be normal.
        // This is to be compatible with the movmatrix instruction.
        a_frag[2],  // NOTE: intentional index swapping for transpose purpose.
        a_frag[1],  // NOTE: intentional index swapping for transpose purpose.
        a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], b_frag_t[3], c_frag[0], c_frag[1],
        c_frag[2], c_frag[3], local_amax_t_reg);
    // Store the result to the shared memory in non-transposed order.
    if constexpr (kReturnTransposed) {
      uint4 *output_t_b128 = reinterpret_cast<uint4 *>(output_t + global_offset_t);
      store_matrix_16x16_to_global<!kOutputTrueTransposed>(
          c_frag[0], c_frag[1], c_frag[2], c_frag[3], output_t_b128,
          kOutputTrueTransposed ? num_input_rows : num_input_cols);
    }
  }
  if constexpr (kUpdateIdentityAmax) {
    unpack_max_of_packed_bf16(local_amax_reg, local_amax);
    local_amax = warp_reduce_max<kThreadsPerWarp>(local_amax);
    // broadcast the amax to all threads in a warp from the lane 0
    constexpr int lane_zero = 0;
    local_amax = __shfl_sync(0xFFFFFFFF, local_amax, lane_zero);
    // atomic CAS to output memory.
    if (tid % kThreadsPerWarp == 0) {
      atomicMaxFloat(amax, local_amax);
    }
  }
  if constexpr (kUpdateTransposeAmax) {
    unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t);
    local_amax_t = warp_reduce_max<kThreadsPerWarp>(local_amax_t);
    // broadcast the amax to all threads in a warp from the lane 0
    constexpr int lane_zero = 0;
    local_amax_t = __shfl_sync(0xFFFFFFFF, local_amax_t, lane_zero);
    // atomic CAS to output memory.
    if (tid % kThreadsPerWarp == 0) {
      atomicMaxFloat(amax_t, local_amax_t);
    }
  }
#else
  NVTE_DEVICE_ERROR("Kernel is only supported on SM 9.0+.");
#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
}
}
// namespace
/* Apply the 16x16 randomized Hadamard transform (RHT) to a BF16 tensor.
 *
 * Input:  `input_` — BF16 tensor with NVTE_DELAYED_TENSOR_SCALING, rank >= 2
 *         (all leading dims are flattened into rows).
 * Output: `output_.data` receives the *transposed* RHT result (see NOTE below);
 *         the identity (non-transposed) output path is currently disabled.
 * `random_sign_mask` / `random_sign_mask_t` — 16-bit sign-flip masks for the
 * rowwise and columnwise transforms respectively.
 *
 * Fixes vs. previous revision:
 *  - `cudaFuncSetAttribute` return code is now checked (a failure would
 *    otherwise surface only as a confusing launch error later).
 *  - The rank-check error message now matches the actual `>= 2` condition.
 */
void hadamard_transform(const Tensor &input_, Tensor &output_, uint16_t random_sign_mask,
                        uint16_t random_sign_mask_t, cudaStream_t stream) {
  NVTE_API_CALL(hadamard_transform);
  // Check tensors
  // NOTE (frsun): This is non-intuitive, we are writing the result of
  // transposed RHT to the output of rowwise.
  NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Input tensor must be BF16 tensor, but scaling mode is ",
             to_string(input_.scaling_mode), ".");
  NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
             "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
  NVTE_CHECK(input_.dim() >= 2, "Input must be at least a 2D tensor.");
  NVTE_CHECK(output_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Output tensor must be simple tensor, but scaling mode is ",
             to_string(output_.scaling_mode), ".");

  const SimpleTensor &input = input_.data;
  // Identity output is intentionally a default-constructed (null) tensor:
  // only the transposed result is produced through `output_.data`.
  SimpleTensor output;
  SimpleTensor &output_t = output_.data;

  // Check requested outputs
  const bool return_identity = output.dptr != nullptr;
  const bool return_transposed = output_t.dptr != nullptr;
  if (!return_identity && !return_transposed) {
    // Nothing to do/ill-defined behavior.
    return;
  }
  checkCuDriverContext(stream);

  // Flatten all leading dimensions into `num_rows`; last dim is `row_length`.
  const size_t ndim = input.shape.size();
  const size_t row_length = input.shape[ndim - 1];
  size_t num_rows = 1;
  for (size_t i = 0; i < ndim - 1; ++i) {
    num_rows *= input.shape[i];
  }

  using IType = bf16;
  constexpr int kHadamardDimension = 16;
  NVTE_CHECK(row_length % kHadamardDimension == 0,
             "row_length must be divisible by hadamard_dimension.");
  NVTE_CHECK(num_rows % kHadamardDimension == 0,
             "num_rows must be divisible by hadamard_dimension");

  constexpr uint64_t kThreadBlockX = 4;
  // Block-Y of 4 is tuned for Hopper; 8 is used on Blackwell for extra memory bandwidth.
  constexpr uint64_t kThreadBlockY = 4;
  uint64_t kNumWarpsPerSM = kThreadBlockX * kThreadBlockY;
  // The shared memory number of bytes required for **the whole threadblock**:
  // one 16x16 BF16 tile per warp.
  size_t shmem_bytes = kHadamardDimension * kHadamardDimension * sizeof(IType) * kNumWarpsPerSM;
  dim3 block(kThreadsPerWarp, kThreadBlockX, kThreadBlockY);
  // Each warp handles one 16x16 tile; grid covers the full (num_rows x row_length) extent.
  dim3 grid(DIVUP(row_length / kHadamardDimension, kThreadBlockX),
            DIVUP(num_rows / kHadamardDimension, kThreadBlockY));

  TRANSFORMER_ENGINE_SWITCH_CONDITION(
      return_transposed, kReturnTransposed,
      TRANSFORMER_ENGINE_SWITCH_CONDITION(
          return_identity, kReturnIdentity,
          auto kernel =
              HadamardTransformKernel<IType, kHadamardDimension, kReturnIdentity,
                                      kReturnTransposed, kReturnIdentity, kReturnTransposed,
                                      false, false, true>;
          // Opt the kernel into the requested dynamic shared-memory size and
          // fail fast if the attribute cannot be set.
          NVTE_CHECK_CUDA(cudaFuncSetAttribute(
              kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes));
          kernel<<<grid, block, shmem_bytes, stream>>>(
              reinterpret_cast<const IType *>(input.dptr), reinterpret_cast<IType *>(output.dptr),
              reinterpret_cast<IType *>(output_t.dptr), random_sign_mask, random_sign_mask_t,
              num_rows, row_length, nullptr, nullptr, false);););
  NVTE_CHECK_CUDA(cudaGetLastError());
}
/* Compute absolute-max statistics for the 16x16 randomized Hadamard transform
 * of `input_` and of its transpose, without materializing the transformed data.
 *
 * Outputs (each optional — skipped when its dptr is null):
 *  - output_.amax            : amax of the raw (pre-RHT) input
 *  - output_.columnwise_amax : amax of the transposed RHT result
 * The identity-RHT amax path is currently disabled (null local tensor).
 *
 * Requires CUDA 12.8+ (TMA tensor-map path); errors out otherwise.
 *
 * Fix vs. previous revision: `cudaFuncSetAttribute` return code is now
 * checked, and the rank-check message matches the `>= 2` condition.
 */
void hadamard_transform_amax(const Tensor &input_, Tensor &output_, uint16_t random_sign_mask,
                             uint16_t random_sign_mask_t, cudaStream_t stream) {
  NVTE_API_CALL(hadamard_transform_amax);
#if CUDA_VERSION >= 12080
  // Check input tensor
  NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Input tensor must be BF16 tensor, but scaling mode is ",
             to_string(input_.scaling_mode), ".");
  NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
             "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
  NVTE_CHECK(input_.dim() >= 2, "Input must be at least a 2D tensor.");
  const SimpleTensor &input = input_.data;

  // Check amax tensors. The identity amax is intentionally a null local
  // tensor; only pre-RHT and transposed amaxes are exposed via `output_`.
  SimpleTensor &output_pre_rht_amax = output_.amax;
  SimpleTensor output_identity_amax;
  SimpleTensor &output_transpose_amax = output_.columnwise_amax;

  // Check requested outputs
  const bool return_pre_rht_amax = output_pre_rht_amax.dptr != nullptr;
  const bool return_identity_amax = output_identity_amax.dptr != nullptr;
  const bool return_transposed_amax = output_transpose_amax.dptr != nullptr;
  if (!return_identity_amax && !return_transposed_amax && !return_pre_rht_amax) {
    // Nothing to do/ill-defined behavior.
    return;
  }

  // Zero out amaxes if needed (kernel tolerates null pointers).
  ZeroAmaxKernel<<<1, 1, 0, stream>>>(reinterpret_cast<float *>(output_pre_rht_amax.dptr),
                                      reinterpret_cast<float *>(output_identity_amax.dptr),
                                      reinterpret_cast<float *>(output_transpose_amax.dptr));
  NVTE_CHECK_CUDA(cudaGetLastError());
  checkCuDriverContext(stream);

  using IType = bf16;
  // Flatten all leading dimensions into `num_rows`; last dim is `row_length`.
  const size_t ndim = input.shape.size();
  const size_t row_length = input.shape[ndim - 1];
  size_t num_rows = 1;
  for (size_t i = 0; i < ndim - 1; ++i) {
    num_rows *= input.shape[i];
  }

  constexpr int kHadamardDimension = 16;
  NVTE_CHECK(row_length % kHadamardDimension == 0,
             "row_length must be divisible by hadamard_dimension.");
  NVTE_CHECK(num_rows % kHadamardDimension == 0,
             "num_rows must be divisible by hadamard_dimension");

  // Chunk processed per threadblock and per-TMA-load buffer dimensions.
  constexpr uint64_t kChunkBlockXSmall = 128;
  constexpr uint64_t kChunkBlockYSmall = 128;
  constexpr uint64_t kBuffDimX = 64;
  constexpr uint64_t kBuffDimY = 64;

  // TMA descriptor for the input; CUtensorMap must be 64-byte aligned.
  alignas(64) CUtensorMap tensor_map_input{};
  create_2D_tensor_map(/*tensorMap=*/tensor_map_input, /*tensor=*/input,
                       /*globalY=*/num_rows, /*globalX=*/row_length,
                       /*shmemY=*/kBuffDimY, /*shmemX=*/kBuffDimX,
                       /*stride_elems=*/row_length, /*offset_elems=*/0,
                       /*type_num_bits=*/sizeof(IType) * 8,
                       /*swizzle=*/CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B);

  constexpr uint64_t kThreadBlockX = 4;
  constexpr uint64_t kThreadBlockY = 1;
  constexpr uint64_t kNumWarps = kThreadBlockX * kThreadBlockY;
  dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY);
  dim3 grid(DIVUP(row_length, kChunkBlockXSmall), DIVUP(num_rows, kChunkBlockYSmall));

  TRANSFORMER_ENGINE_SWITCH_CONDITION(
      return_transposed_amax, kReturnTransposedAmax,
      TRANSFORMER_ENGINE_SWITCH_CONDITION(
          return_identity_amax, kReturnIdentityAmax,
          TRANSFORMER_ENGINE_SWITCH_CONDITION(
              return_pre_rht_amax, kReturnPreRhtAmax,
              // *2 for ping-pong
              size_t in_sh_size = kBuffDimX * kBuffDimY * 2 * sizeof(IType);
              // One mbarrier per in-flight buffer tile.
              size_t mbar_size = sizeof(uint64_t) * (kChunkBlockXSmall / kBuffDimX) *
                                 (kChunkBlockYSmall / kBuffDimY);
              size_t shmem_bytes = in_sh_size + mbar_size + kNumWarps * sizeof(float) * 3;
              // Add padding in case shmem ptr is not aligned to 128 bytes.
              shmem_bytes = (shmem_bytes + 128);
              auto kernel = HadamardAmaxTmaKernel<IType, kHadamardDimension, kChunkBlockYSmall,
                                                  kChunkBlockXSmall, kBuffDimY, kBuffDimX,
                                                  kThreadBlockX * kThreadsPerWarp, kThreadBlockY,
                                                  kReturnPreRhtAmax, kReturnIdentityAmax,
                                                  kReturnTransposedAmax>;
              // Fail fast if the dynamic shared-memory request is rejected.
              NVTE_CHECK_CUDA(cudaFuncSetAttribute(
                  kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes));
              kernel<<<grid, block, shmem_bytes, stream>>>(
                  tensor_map_input, reinterpret_cast<float *>(output_pre_rht_amax.dptr),
                  reinterpret_cast<float *>(output_identity_amax.dptr),
                  reinterpret_cast<float *>(output_transpose_amax.dptr), random_sign_mask,
                  random_sign_mask_t, num_rows, row_length);)));
  NVTE_CHECK_CUDA(cudaGetLastError());
#else
  NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ",
             CUDA_VERSION);
#endif  // CUDA_VERSION >= 12080
}
}
// namespace transformer_engine
void
nvte_hadamard_transform
(
const
NVTETensor
input
,
NVTETensor
output
,
int
random_sign_mask
,
int
random_sign_mask_t
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_hadamard_transform
);
using
namespace
transformer_engine
;
hadamard_transform
(
*
convertNVTETensorCheck
(
input
),
*
convertNVTETensorCheck
(
output
),
static_cast
<
uint16_t
>
(
random_sign_mask
),
static_cast
<
uint16_t
>
(
random_sign_mask_t
),
stream
);
}
void
nvte_hadamard_transform_amax
(
const
NVTETensor
input
,
NVTETensor
output
,
int
random_sign_mask
,
int
random_sign_mask_t
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_hadamard_transform_amax
);
using
namespace
transformer_engine
;
hadamard_transform_amax
(
*
convertNVTETensorCheck
(
input
),
*
convertNVTETensorCheck
(
output
),
static_cast
<
uint16_t
>
(
random_sign_mask
),
static_cast
<
uint16_t
>
(
random_sign_mask_t
),
stream
);
}
transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu
0 → 100644
View file @
063ef88d
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <cutlass/arch/barrier.h>
#include <transformer_engine/hadamard_transform.h>
#include <cuda/barrier>
#include <cute/algorithm/gemm.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#include "common/common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "curanddx.hpp"
#include "cutlass/arch/barrier.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/builders/sm100_common.inl"
#include "cutlass/numeric_conversion.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/helper_cuda.hpp"
#include "cutlass/util/print_error.hpp"
// clang-format off
namespace
transformer_engine
{
namespace
detail
{
namespace
{
// Define a cuRANDDx descriptor
// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it will be default to 10.
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for do some internal checks, e.g.,
// if shared memory, if needed, is enough for the described problem, usually not applicable.
// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
using
RNG
=
decltype
(
curanddx
::
Generator
<
curanddx
::
philox4_32
>
()
+
curanddx
::
PhiloxRounds
<
10
>
()
+
curanddx
::
SM
<
800
>
()
+
curanddx
::
Thread
());
using
namespace
cute
;
using
cute
::
Tensor
;
// Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor
// Compute the global FP4 encode scale for a given global amax:
// scale = (FP8_E4M3_MAX * FP4_E2M1_MAX) / amax, clamped to float max with
// NaN propagation. Degenerate cases (amax == 0, or amax == inf which drives
// the scale to 0) fall back to an identity scale of 1.
__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) {
  constexpr float kFP8E4M3Max = 448.0f;
  constexpr float kFP4E2M1Max = 6.0f;
  const float raw_scale = kFP8E4M3Max * kFP4E2M1Max / global_amax;
  // If the division overflows to infinity, clamp to float max (NaN propagates).
  const float clamped_scale = cutlass::minimum_with_nan_propagation<float>{}(
      raw_scale, cutlass::platform::numeric_limits<float>::max());
  // If global amax is 0 or infinity, return 1.
  if (global_amax == 0.f || clamped_scale == 0.f) {
    return 1.f;
  }
  return clamped_scale;
}
// Shared-memory layout for the RHT-GEMM kernel. Holds the storage for both
// pipelines (mainloop TMA->UMMA and accumulator UMMA->epilogue), the one-shot
// TMA barrier used to stage the 16x16 RHT matrix, the TMEM base pointer
// published by the allocating warp, and the A/B operand staging buffers.
// Alignment requirements (16B for pipeline storage, 128B for operand tensors)
// follow the CUTLASS/cute conventions for TMA and UMMA operands.
template <class ElementA, class ElementB, class ASmemLayout, class BSmemLayout>
struct SharedStorage {
  // Accumulator pipeline: 16 logical stages, consumed 4-at-a-time (hence /4).
  static constexpr int AccumulatorPipelineStageCount = 16;
  using AtomThrShapeMNK = cute::Shape<_1, _1, _1>;
  using AccumulatorPipeline =
      cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / 4, AtomThrShapeMNK>;
  using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage;
  // Mainloop stage count is derived from the pipe mode (mode 3) of ASmemLayout.
  static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{});
  using MainloopPipeline =
      cutlass::PipelineTmaUmmaAsync<MainloopPipelineStageCount, Shape<_1, _1, _1>,
                                    AtomThrShapeMNK>;
  using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage;

  alignas(16) AccumulatorPipelineStorage accumulator;  // accumulator pipeline barriers
  alignas(16) MainloopPipelineStorage mainloop;        // mainloop pipeline barriers
  // Single mbarrier used once to await the TMA load of the 16x16 RHT matrix (B).
  alignas(16) cute::uint64_t tma_barrier[1];
  // TMEM base address, written by the allocating (MMA) warp, read by epilogue warps.
  uint32_t tmem_base_ptr;

  // Operand staging buffers; 128-byte alignment for TMA/UMMA.
  struct TensorStorage : cute::aligned_struct<128, _1> {
    // cute::array_aligned<ElementA, cute::cosize_v<ASmemLayout>> smem_A;
    cute::array_aligned<ElementA, cute::cosize_v<ASmemLayout>> smem_A;
    cute::array_aligned<ElementB, cute::cosize_v<BSmemLayout>> smem_B;
  } tensors;
};
// Stochastically round 8 floats to 8 FP4 (e2m1) values using the SM10x
// `cvt.rs.satfinite.e2m1x4.f32` instruction (4 floats per cvt, 2 cvts total).
// `rbits` supplies the random bits consumed by the round-to-stochastic mode
// (one 32-bit word per group of 4 conversions). Note the operand order in the
// PTX: inputs are passed highest-index first ({%5,%4,%3,%2}), matching the
// instruction's packing convention.
// Only compiled for architecture-specific SM10x targets; otherwise traps.
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 8> StochasticNumericConverterBase(
    cutlass::Array<float, 8> const &input, cutlass::Array<uint32_t, 2> const &rbits) {
  using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>;
  result_type output;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  // Each cvt.rs writes 4 packed e2m1 nibbles into one 16-bit register.
  auto output_ptr = reinterpret_cast<uint16_t *>(&output);
  asm volatile(
      "{\n"
      "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n"
      "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n"
      "}"
      : "=h"(output_ptr[0]), "=h"(output_ptr[1])
      : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]),
        "f"(input[5]), "f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1]));
#else
  NVTE_DEVICE_ERROR(
      "FP4 cvt PTX instructions are architecture-specific. "
      "Try recompiling with sm_XXXa instead of sm_XXX.");
#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  return output;
}
// Stochastically round 16 floats to 16 FP4 (e2m1) values by delegating to the
// 8-wide base converter twice. `rbits` must point to 4 random 32-bit words
// (2 words per 8-element half).
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 16> StochasticNumericConverter(
    cutlass::Array<float, 16> const &input, cutlass::Array<uint32_t, 4> const *rbits) {
  cutlass::Array<cutlass::float_e2m1_t, 16> converted;
  // View the 16-wide arrays as two 8-wide halves so each half maps onto one
  // base-converter call.
  auto *out_halves = reinterpret_cast<cutlass::Array<cutlass::float_e2m1_t, 8> *>(&converted);
  auto const *in_halves = reinterpret_cast<cutlass::Array<float, 8> const *>(&input);
  auto const *rbit_halves = reinterpret_cast<cutlass::Array<uint32_t, 2> const *>(rbits);
  CUTLASS_PRAGMA_UNROLL
  for (int half = 0; half < 2; ++half) {
    out_halves[half] = StochasticNumericConverterBase(in_halves[half], rbit_halves[half]);
  }
  return converted;
}
// Fused RHT + FP4 cast kernel (SM100 / Blackwell, warp-specialized).
//
// Computes C = quantize_fp4(A x B) where B is the 16x16 randomized Hadamard
// matrix, writing packed FP4 values to C and per-16-element block scale
// factors to SFC. Work is split across warps of one threadblock:
//   - warp 1 (DMA):       TMA-loads A tiles (and B once) into shared memory,
//   - warp 0 (MMA):       issues UMMA gemms accumulating into tensor memory,
//   - warps 4-7 (epilogue): read accumulators from TMEM, compute per-vector
//     scale factors, cast to FP4 (optionally with stochastic rounding), and
//     store C and SFC to global memory.
// The three roles are connected by two producer/consumer pipelines
// (mainloop: smem stages; accumulator: TMEM stages).
// Tiles are distributed over blocks in a persistent / wave-based scheme:
// each block walks linear_tile_idx += gridDim.x until tiles are exhausted.
template <class MShape, class NShape, class KShape, class ClusterTileShape, class TA,
          class AStride, class ASmemLayout, class TmaLoadA, class TB, class BStride,
          class BSmemLayout, class TmaLoadB, class TC, class CStride, class CSmemLayout,
          class TSFC, class TiledMMA, bool kEnableStochasticRounding = false>
__global__ static void rht_gemm_device(MShape M, NShape N, KShape K,
                                       ClusterTileShape cluster_tile, TA const *A, AStride dA,
                                       ASmemLayout sAlayout,
                                       CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, TB const *B,
                                       BStride dB, BSmemLayout sBlayout,
                                       CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, TC *C,
                                       CStride dC, CSmemLayout, TSFC *SFC, TiledMMA mma,
                                       float const *global_amax, const size_t *rng_state) {
  using namespace cute;
  using X = Underscore;
  // static constexpr bool kApplyStochasticRounding = true;
  using ElementAccumulator = float;
  static constexpr int K_PIPE_MAX = size<3>(ASmemLayout{});
  using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>;
  // Bytes transferred per mainloop TMA stage (one A tile).
  static constexpr uint32_t kTmaTransactionBytes = cutlass::bits_to_bytes(
      size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v<TA>);
  // Bytes for the one-shot load of the 16x16 RHT matrix.
  static constexpr int kTmaRhtTensorTransactionBytes =
      cutlass::bits_to_bytes(16 * 16 * cute::sizeof_bits_v<TB>);
  static constexpr int AccumulatorPipelineStageCount = 16;
  static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{});
  using MainloopPipeline = cutlass::PipelineTmaUmmaAsync<MainloopPipelineStageCount,
                                                         Shape<_1, _1, _1>, AtomThrShapeMNK>;
  using MainloopPipelineState = typename MainloopPipeline::PipelineState;
  using TmemAllocator = cute::TMEM::Allocator1Sm;
  // FP4 scale-factor granularity: one scale per 16 consecutive outputs.
  static constexpr int VectorSize = 16;
  // RNG seed/offset for stochastic rounding; zero when no state is provided.
  const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
  const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;

  // Preconditions
  CUTE_STATIC_ASSERT(is_static<ASmemLayout>::value);
  CUTE_STATIC_ASSERT(is_static<BSmemLayout>::value);
  CUTE_STATIC_ASSERT(is_static<CSmemLayout>::value);

  // Represent the full tensors
  Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, N));
  Tensor mB = tma_load_b.get_tma_tensor(make_shape(16, 16));
  // C holds sub-byte (FP4) elements, hence the subbyte iterator.
  Tensor mC = make_tensor(cute::subbyte_iterator<TC>(C), make_shape(M, N), dC);  // (M,N)
  // SFC layout: one scale per 16 outputs, grouped 4 scales per 64-column chunk.
  auto sfc_shape = make_shape(M, make_shape(make_shape(Int<16>{}, _4{}), N / 64));
  auto sfc_stride = make_stride(N / 16, make_stride(make_stride(_0{}, _1{}), _4{}));
  auto sfc_layout = make_layout(sfc_shape, sfc_stride);
  Tensor mSFC = make_tensor(make_gmem_ptr(SFC), sfc_layout);
  auto cluster_shape = Shape<_1, _1, _1>{};
  // Get the appropriate blocks for this Cluster
  dim3 cluster_coord_in_grid = cluster_id_in_grid();
  // Total number of k-tiles
  const int K_TILE_MAX = min(N, K) / 64;
  uint32_t tiles_in_m = (M + size<0>(cluster_tile) - 1) / size<0>(cluster_tile);
  uint32_t tiles_in_n = (N + 64 - 1) / 64;
  // Persistent-block tile scheduling state.
  uint32_t linear_tile_idx = blockIdx.x;
  uint32_t tile_idx_m = linear_tile_idx % tiles_in_m;
  uint32_t tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
  auto mainloop_tiler = Shape<_128, _16, _64>{};
  auto epilogue_tiler = Shape<_128, _64, _64>{};
  Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{});
  Tensor gB_nk = local_tile(mB, cluster_tile, make_coord(_, _, _),
                            Step<X, _1, _1>{});  // (BLK_N,BLK_K,k)
  Tensor gC_mn = local_tile(mC, epilogue_tiler, make_coord(_, _, _),
                            Step<_1, _1, X>{});  // (BLK_M,BLK_N)
  Tensor gSFC_mn = local_tile(mSFC, epilogue_tiler, make_coord(_, _, _),
                              Step<_1, _1, X>{});  // (BLK_M,BLK_N)

  // Allocate SMEM
  extern __shared__ char shared_memory[];
  using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
  SharedStorage &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);
  Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()),
                            sAlayout);  // (MMA,MMA_M,MMA_N,PIPE)
  Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()),
                            sBlayout);  // (MMA,MMA_N,MMA_K,PIPE)

  //
  // MMA: Define C accumulators and A/B partitioning
  //
  int block_rank_in_cluster = cute::block_rank_in_cluster();
  ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster);  // blk idx
  Tensor tCgB = thr_mma.partition_B(gB_nk);               // (MMA,MMA_N,MMA_K,k)
  // Separate 128x64 MMA shape used only to shape the epilogue partitioning.
  auto mma_epilogue = make_tiled_mma(
      SM100_MMA_F16BF16_SS<TA, TB, ElementAccumulator, 128, 64, UMMA::Major::MN,
                           UMMA::Major::MN>{},
      Layout<Shape<_1, _1>>{});
  ThrMMA thr_mma_epilogue = mma_epilogue.get_slice(block_rank_in_cluster);
  using TiledMmaEpilogue = decltype(mma_epilogue);
  Tensor tCgA = thr_mma.partition_A(gA_mk);
  // Allocate "fragments" -- these are actually umma smem descriptors
  Tensor tCrA = thr_mma.make_fragment_A(tCsA);  // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCrB = thr_mma.make_fragment_B(tCsB);  // (MMA,MMA_M,MMA_K,PIPE)
  auto acc_shape_mma = partition_shape_C(TiledMMA{}, take<0, 2>(ClusterTileShape{}));
  auto acc_shape_epilogue = partition_shape_C(TiledMmaEpilogue{}, take<0, 2>(epilogue_tiler));
  // Two views over the same TMEM accumulators: 16 stages for the MMA warp,
  // 4 stages (4x wider each) for the epilogue warps.
  auto bulk_tmem_mma =
      TiledMMA::make_fragment_C(append(acc_shape_mma, Int<AccumulatorPipelineStageCount>{}));
  auto bulk_tmem_epilogue = TiledMmaEpilogue::make_fragment_C(
      append(acc_shape_epilogue, Int<AccumulatorPipelineStageCount / 4>{}));
  TmemAllocator tmem_allocator{};
  // Synchronizes TMEM allocation between the MMA warp (32) and epilogue warps (128).
  cutlass::arch::NamedBarrier tmem_allocation_result_barrier(
      32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier);
  Layout cta_layout_mnk = make_layout(cluster_shape);
  Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{}));
  auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster);
  auto [tAgA, tAsA] =
      tma_partition(tma_load_a, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)),
                    group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA));
  auto [tBgB, tBsB] =
      tma_partition(tma_load_b, get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)),
                    group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB));
  uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
  uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk);

  // Warp specialization roles.
  int warp_idx = cutlass::canonical_warp_idx_sync();
  bool is_mma_warp = (warp_idx == 0);
  bool is_dma_warp = (warp_idx == 1);
  bool is_epilogue_warp = (warp_idx >= 4 && warp_idx <= 7);
  if (is_epilogue_warp && elect_one_sync()) {
    cute::prefetch(raw_pointer_cast(global_amax));
  }

  // Mainloop pipeline: DMA warp produces (TMA loads), MMA warp consumes.
  typename MainloopPipeline::Params mainloop_pipeline_params;
  if (is_dma_warp) {
    mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
  }
  if (is_mma_warp) {
    mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
  }
  mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp;
  mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes;
  mainloop_pipeline_params.initializing_warp = 0;
  MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params,
                                     cluster_shape,
                                     cute::true_type{},   // Perform barrier init
                                     cute::true_type{});  // Delay mask calculation
  MainloopPipelineState mainloop_pipe_consumer_state;
  MainloopPipelineState mainloop_pipe_producer_state =
      cutlass::make_producer_start_state<MainloopPipeline>();

  // Accumulator pipeline: MMA warp produces (TMEM results), epilogue consumes.
  using AccumulatorPipeline =
      cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / 4, AtomThrShapeMNK>;
  using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState;
  AccumulatorPipelineState accumulator_pipe_consumer_state;
  AccumulatorPipelineState accumulator_pipe_producer_state =
      cutlass::make_producer_start_state<AccumulatorPipeline>();
  typename AccumulatorPipeline::Params accumulator_pipeline_params;
  if (is_mma_warp) {
    accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer;
  }
  if (is_epilogue_warp) {
    accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer;
  }
  // Only one producer thread arrives on this barrier.
  accumulator_pipeline_params.producer_arv_count = 1;
  accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * 128;
  accumulator_pipeline_params.initializing_warp = 1;
  AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator,
                                           accumulator_pipeline_params, cluster_shape,
                                           cute::true_type{},   // Perform barrier init
                                           cute::true_type{});  // Delay mask calculation
  // A third warp initializes the one-shot TMA barrier for the B (RHT) matrix.
  if (warp_idx == 2 && elect_one_sync()) {
    cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1);
  }
  __syncthreads();
  using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x;

  if (is_dma_warp) {
    // ---- DMA warp: load B once, then stream A tiles through the pipeline. ----
    if (elect_one_sync()) {
      cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0],
                                          kTmaRhtTensorTransactionBytes);
      copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0),
           tBsB(_, 0));
    }
    cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/);
    do {
      // On the first wave the pipeline stages are known-empty, so the acquire
      // wait can be skipped while stages are still being filled.
      bool is_first_wave = linear_tile_idx == blockIdx.x;
      uint32_t skip_wait = is_first_wave;
      auto tAgA_mk = tAgA(_, tile_idx_m, _);
      int k_tile = 0;
      auto barrier_token =
          mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait);
      CUTE_NO_UNROLL
      while (k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n) {
        int k_tile_idx_n = tile_idx_n + k_tile;
        ++k_tile;
        skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount);
        mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token);
        using BarrierType = typename MainloopPipeline::ProducerBarrierType;
        BarrierType *tma_barrier =
            mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state);
        int write_stage = mainloop_pipe_producer_state.index();
        ++mainloop_pipe_producer_state;
        // Kick off the next acquire before issuing this stage's TMA.
        barrier_token =
            mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait);
        if (cute::elect_one_sync()) {
          copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n),
               tAsA(_, write_stage));
        }
      }
      // Advance to this block's next persistent tile.
      linear_tile_idx += gridDim.x;
      tile_idx_m = linear_tile_idx % tiles_in_m;
      tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
    } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
    mainloop_pipeline.producer_tail(mainloop_pipe_producer_state);
  } else if (is_mma_warp) {
    // ---- MMA warp: allocate TMEM, then drain smem stages into UMMA gemms. ----
    mma.accumulate_ = UMMA::ScaleOut::Zero;
    tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns,
                            &shared_storage.tmem_base_ptr);
    __syncwarp();
    tmem_allocation_result_barrier.arrive();
    uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
    bulk_tmem_mma.data() = tmem_base_ptr;
    do {
      uint32_t skip_wait = K_TILE_MAX <= 0;
      auto barrier_token =
          mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
      CUTE_NO_UNROLL
      for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n;) {
        mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
        int read_stage = mainloop_pipe_consumer_state.index();
        auto tCrA_mk = tCrA(_, _, _, read_stage);
        auto tCrB_nk = tCrB(_, _, 0, 0);
        CUTE_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA) / 4; ++k_block) {
          // Each accumulator stage covers 4 MMA sub-blocks.
          accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
          CUTE_UNROLL
          for (int i = 0; i < 4; i++) {
            auto accumulators =
                bulk_tmem_mma(_, _, _, accumulator_pipe_producer_state.index() * 4 + i);
            gemm(mma, tCrA_mk(_, _, k_block * 4 + i), tCrB_nk, accumulators);
          }
          accumulator_pipeline.producer_commit(accumulator_pipe_producer_state);
          ++accumulator_pipe_producer_state;
        }
        auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;
        ++mainloop_pipe_consumer_state;
        ++k_tile;
        skip_wait = k_tile >= K_TILE_MAX;
        barrier_token =
            mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
        mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
      }
      // Advance to this block's next persistent tile.
      linear_tile_idx += gridDim.x;
      tile_idx_m = linear_tile_idx % tiles_in_m;
      tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
    } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
    tmem_allocator.release_allocation_lock();
    accumulator_pipeline.producer_tail(accumulator_pipe_producer_state);
    tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns);
  } else if (is_epilogue_warp) {
    // ---- Epilogue warps: TMEM -> scale/quantize -> global FP4 + scales. ----
    const float global_amax_val = *global_amax;
    static constexpr int FragmentSize = 256 / sizeof_bits_v<TC>;
    tmem_allocation_result_barrier.arrive_and_wait();
    uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
    bulk_tmem_epilogue.data() = tmem_base_ptr;
    int thread_idx = threadIdx.x % 128;
    Tensor tCgC = thr_mma_epilogue.partition_C(gC_mn);  // (MMA,MMA_M,MMA_N) // (MMA,MMA_M,MMA_N)
    auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{}));
    auto tiled_r2g =
        make_tiled_copy_D(Copy_Atom<SM100_STORE_256bit_CACHE_NOALLOCATION, TC>{}, tiled_t2r);
    auto thr_t2r = tiled_t2r.get_slice(thread_idx);
    auto thr_r2g = tiled_r2g.get_slice(thread_idx);
    // NVFP4 non-E8 recipe constants and global scales
    static constexpr float fp4_max = 6.0f;
    const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val);
    const float global_decode_scale = 1.0f / global_encode_scale;
    // NOTE(review): sfd_converter appears unused below — conversions go
    // through NumericArrayConverter instead. Candidate for removal.
    auto sfd_converter = cutlass::NumericConverter<TSFC, float>{};
    do {
      for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) {
        Tensor tCgC_mn = tCgC(_, _, _, tile_idx_m, tile_idx_n + k_tile);
        Tensor tCgSFC_mn = gSFC_mn(_, _, tile_idx_m, tile_idx_n + k_tile);
        accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state);
        auto tCtC = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index());
        Tensor tDtC = thr_t2r.partition_S(tCtC);     // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
        Tensor tDgC = thr_t2r.partition_D(tCgC_mn);  // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
        Tensor tTR_rAcc =
            make_tensor<ElementAccumulator>(shape(tDgC));  // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
        Tensor tDrC = make_tensor<TC>(shape(tDgC));
        Tensor tTR_rAcc_frag =
            recast<cutlass::Array<ElementAccumulator, FragmentSize>>(coalesce(tTR_rAcc));
        Tensor tDrC_frag = recast<cutlass::Array<TC, FragmentSize>>(coalesce(tDrC));
        Tensor src = thr_r2g.retile_S(tDrC);
        Tensor dst = thr_r2g.retile_D(tDgC);
        // Broadcast view of the SFC tile so it can be partitioned like C.
        Tensor tCgSFC = make_tensor(
            tCgSFC_mn.data(),
            make_layout(make_shape(shape(tCgSFC_mn), Int<1>{}, Int<1>{}),
                        make_stride(stride(tCgSFC_mn), Int<0>{}, Int<0>{})));
        Tensor tDgSFC = filter(thr_t2r.partition_D(tCgSFC));
        Tensor tDrSFC = make_tensor<TSFC>(shape(tDgSFC));
        static constexpr int NumVecs = size(tDgC) / VectorSize;
        Tensor tC_rRowSFD_frg = recast<cutlass::Array<TSFC, NumVecs>>(tDrSFC);
        cutlass::maximum_absolute_value_reduction<cutlass::Array<ElementAccumulator, VectorSize>,
                                                  true>
            amax_reduction;
        cutlass::Array<ElementAccumulator, NumVecs> vec_maxs;
        cutlass::Array<ElementAccumulator, NumVecs> pvscales;
        // TMEM_LOAD
        copy(tiled_t2r, tDtC, tTR_rAcc);
        cutlass::arch::fence_view_async_tmem_load();
        accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state);
        ++accumulator_pipe_consumer_state;
        // Cast data from FP32 to BF16 to FP32.
        auto convert_accum_to_bf16 =
            cutlass::NumericArrayConverter<cutlass::bfloat16_t, ElementAccumulator,
                                           FragmentSize>{};
        auto convert_bf16_to_accum =
            cutlass::NumericArrayConverter<ElementAccumulator, cutlass::bfloat16_t,
                                           FragmentSize>{};
        tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{})));
        auto compute_frgs = reinterpret_cast<cutlass::Array<ElementAccumulator, VectorSize> *>(
            tTR_rAcc_frag.data());
        auto output_frgs =
            reinterpret_cast<cutlass::Array<TC, VectorSize> *>(tDrC_frag.data());
        // Per-16-element amax -> per-vector FP4 scale factor.
        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < NumVecs; v++) {
          vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]);
        }
        pvscales =
            cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(vec_maxs, fp4_max);
        pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(
            pvscales, global_encode_scale);
        auto pvscales_cvted =
            cutlass::NumericArrayConverter<TSFC, ElementAccumulator, NumVecs>{}(pvscales);
        tC_rRowSFD_frg(_0{}) = pvscales_cvted;
        // Re-derive the effective per-vector multiplier from the *quantized*
        // scale so that output = value / (decoded_scale) exactly.
        auto qpvscale_ups = cutlass::NumericArrayConverter<ElementAccumulator, TSFC, NumVecs>{}(
            tC_rRowSFD_frg(_0{}));
        auto qpvscale_scaled = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(
            qpvscale_ups, global_decode_scale);
        auto acc_scales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(
            1.0, qpvscale_scaled);
        // Initialize RNG for tile
        const size_t rng_sequence =
            thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256;
        RNG rng(rng_seed, rng_sequence, rng_offset);
        curanddx::uniform_bits dist;
        uint4 random_uint4 = uint4{0, 0, 0, 0};
        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < NumVecs; v++) {
          // Clamp the multiplier (it is +inf when the quantized scale is 0).
          auto acc_scale = cutlass::minimum_with_nan_propagation<ElementAccumulator>{}(
              acc_scales[v], cutlass::platform::numeric_limits<ElementAccumulator>::max());
          // auto acc_scale = acc_scales[v];
          if constexpr (kEnableStochasticRounding) {
            random_uint4 = dist.generate4(rng);
            output_frgs[v] = StochasticNumericConverter(
                cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
                    compute_frgs[v], acc_scale),
                reinterpret_cast<cutlass::Array<uint32_t, 4> *>(&random_uint4));
          } else {
            output_frgs[v] = cutlass::NumericArrayConverter<TC, ElementAccumulator, VectorSize>{}(
                cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
                    compute_frgs[v], acc_scale));
          }
        }
        copy(tiled_r2g, src, dst);
        copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFC, tDgSFC);
      }
      // Advance to this block's next persistent tile.
      linear_tile_idx += gridDim.x;
      tile_idx_m = linear_tile_idx % tiles_in_m;
      tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
    } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
  }
}
// this function computes RHT-GEMM for
// A: m x n: col-major
// B: 16 x 16: row-major
// C: m x n: row-major
// SFC: m x (n/16): row-major
//
// Host-side launcher: builds the CuTe layouts / TMA descriptors and launches the
// persistent rht_gemm_device kernel.
//
// \param m, n         Problem shape. m must be divisible by the CGA tile M (128)
//                     and n by 4x the CGA tile N (4*16) -- enforced below.
// \param A, B         Input operands (B is the 16x16 Hadamard matrix).
// \param C            Quantized output buffer.
// \param SFC          Per-vector scale factors for C.
// \param global_amax  Device pointer to the global amax used for scaling.
// \param rng_state    Device RNG state (consumed when kEnableStochasticRounding).
// \param sm_count     SM count; caps the persistent grid size.
// \param stream       CUDA stream used for the launch.
// \param k_tile_size  Portion of N processed per persistent-tile step.
template <typename TA, typename TB, typename TC, typename TSFC,
          bool kEnableStochasticRounding = false>
void rht_gemm_ntt_w_sfc(int m, int n, TA const *A, TB const *B, TC *C, TSFC *SFC,
                        float const *global_amax, const size_t *rng_state, uint32_t sm_count,
                        cudaStream_t stream, int k_tile_size = 2048) {
  using namespace cute;
  // Define shapes (dynamic)
  auto M = static_cast<int>(m);
  auto N = static_cast<int>(n);
  // Define strides (mixed)
  auto dA = make_stride(Int<1>{}, m);   // (dM,dK)
  auto dB = make_stride(Int<1>{}, 16);  // (dN,dK)
  auto dC = make_stride(n, Int<1>{});   // (dM,dN)
  auto cga_shape = Shape<_1, _1, _1>{};
  auto cga_tile_shape = Shape<_128, _16, _16>{};
  auto cluster_tile_mainloop = Shape<_128, _16, _64>{};
  // Construct the MMA
  auto mma = make_tiled_mma(
      SM100_MMA_F16BF16_SS<TA, TB, float, 128, 16, UMMA::Major::MN, UMMA::Major::MN>{},
      Layout<Shape<_1, _1>>{});
  // MMA in CGA Layout XXX: Need to generalize synchro? {$nv-release-never}
  // Assert that the TiledMMA uses all CTAs in the CGA.
  CUTE_STATIC_ASSERT_V(size(cga_shape) == size(mma));
  CUTE_STATIC_ASSERT_V(evenly_divides(cga_tile_shape, tile_shape(mma)));
  // Determine the A and B shapes
  auto mma_shape_B =
      partition_shape_B(mma, make_shape(size<1>(cga_tile_shape), size<2>(cga_tile_shape)));
  using TiledMma = decltype(mma);
  using AtomThrID = typename TiledMma::AtomThrID;
  using SmemShape_M = decltype(shape_div(
      shape<0>(cga_tile_shape),
      shape_div(shape<0>(cga_tile_shape), size<0>(cga_tile_shape) / size(AtomThrID{}))));
  using SmemShape_N = decltype(shape_div(
      shape<1>(cga_tile_shape),
      shape_div(shape<1>(cga_tile_shape), size<1>(cga_tile_shape) / size(AtomThrID{}))));
  using SmemShape_K = decltype(cute::get<2>(cga_tile_shape));
  using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
                                   cute::UMMA::Major::MN, TB, SmemShape_N, SmemShape_K>());
  auto mma_shape_A = partition_shape_A(
      mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop)));
  using SmemShape_M_A = decltype(shape_div(
      shape<0>(cluster_tile_mainloop),
      shape_div(shape<0>(cluster_tile_mainloop),
                size<0>(cluster_tile_mainloop) / size(AtomThrID{}))));
  using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop));
  using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
                                   cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>());
  // Define the smem layouts (static)
  // Calculate max pipeline stages from the per-SM dynamic shared-memory opt-in
  // limit on Blackwell SM100: 232448 bytes (= 227 KiB).
  constexpr int kBlackwellSmemSize = 232448;
  constexpr int kBytesPerStage =
      cute::size(mma_shape_A) * sizeof(TA) + cute::size(mma_shape_B) * sizeof(TB);
  constexpr int kReservedBytes = 256;  // Reserve for barriers and other uses
  constexpr int kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage;
  auto sP = Int<kMaxStages>{};  // SMEM pipelines
  auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{},
                                    append(mma_shape_A, sP));  // (MMA,MMA_M,MMA_K,PIPE)
  auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{},
                                    append(mma_shape_B, sP));  // (MMA,MMA_N,MMA_K,PIPE)
  auto sC = Layout<_1>{};  // XXX Dummy
  // Create GMEM tensors
  Tensor tensorA = make_tensor(A, make_layout(make_shape(M, N), dA));    // (M,N)
  Tensor tensorB = make_tensor(B, make_layout(make_shape(16, 16), dB));  // (16,16)
  // Create the TiledCopy
  auto tma_load_a =
      make_tma_copy_A_sm100(SM90_TMA_LOAD{}, tensorA, sA(_, _, _, 0), cluster_tile_mainloop, mma);
  auto tma_load_b =
      make_tma_copy_B_sm100(SM90_TMA_LOAD{}, tensorB, sB(_, _, _, 0), cga_tile_shape, mma);
  // Assert checks on tile sizes -- no predication
  NVTE_CHECK(M % size<0>(cga_tile_shape) == 0, "Inner dimension must be divisible by ",
             static_cast<size_t>(size<0>(cga_tile_shape)), " but got ", M, ".");
  NVTE_CHECK(N % (4 * size<1>(cga_tile_shape)) == 0, "Outer dimension must be divisible by ",
             4 * static_cast<size_t>(size<1>(cga_tile_shape)), " but got ", N, ".");
  // Persistent kernel: never launch more CTAs than there are SMs.
  uint32_t tiles = size(ceil_div(M, get<0>(cga_tile_shape))) * size(ceil_div(N, k_tile_size));
  tiles = (tiles < sm_count) ? tiles : sm_count;
  dim3 dimBlock(256);
  dim3 dimGrid(tiles, 1, 1);
  int smem_size = sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>);
  auto *kernel_ptr =
      &rht_gemm_device<decltype(M), decltype(N), decltype(k_tile_size), decltype(cga_tile_shape),
                       TA, decltype(dA), decltype(sA), decltype(tma_load_a), TB, decltype(dB),
                       decltype(sB), decltype(tma_load_b), TC, decltype(dC), decltype(sC), TSFC,
                       decltype(mma), kEnableStochasticRounding>;
  // BUGFIX: cudaFuncSetAttribute returns cudaError_t, not bool. Storing the
  // result in a bool collapsed all error codes and made the comparison with
  // cudaSuccess type-unsound (it only happened to work because cudaSuccess == 0).
  cudaError_t status =
      cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  if (status != cudaSuccess) {
    std::cerr << "Error: Failed to set Shared Memory size." << std::endl;
    return;
  }
  (*kernel_ptr)<<<dimGrid, dimBlock, smem_size, stream>>>(
      M, N, k_tile_size, cga_tile_shape, A, dA, sA, tma_load_a, B, dB, sB, tma_load_b, C, dC, sC,
      SFC, mma, global_amax, rng_state);
}
// Thin adapter realizing the "ttt" variant: it transposes the input tensor A by
// swapping the two problem dimensions before forwarding to rht_gemm_ntt_w_sfc.
//
// Besides the transpose, swapping (m, n) lets us reshape the problem so as to
// utilize as many SMs as possible while keeping a relatively large contiguous
// dimension. After the swap, the RHT-GEMM operand shapes are:
//   A:   n x m:      col-major
//   B:   16 x 16:    row-major
//   C:   n x m:      row-major
//   SFC: n x (m/16): row-major
template <typename TA, typename TB, typename TC, typename TSFC,
          bool kEnableStochasticRounding = false>
void rht_gemm_ttt_wrapper(int m, int n, TA const *A, TB const *B, TC *C, TSFC *SFC,
                          float const *global_amax, const size_t *rng_state, uint32_t sm_count,
                          cudaStream_t stream, int k_tile_size = 1024) {
  rht_gemm_ntt_w_sfc<TA, TB, TC, TSFC, kEnableStochasticRounding>(
      n, m, A, B, C, SFC, global_amax, rng_state, sm_count, stream, k_tile_size);
}
}
// namespace
}
// namespace detail
// clang-format on
void
hadamard_transform_cast_fusion_columnwise
(
const
Tensor
&
input_
,
Tensor
&
output_
,
const
Tensor
&
hadamard_matrix_
,
QuantizationConfig
quant_config
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
hadamard_transform_cast_fusion_columnwise
);
// Check input and output tensors
NVTE_CHECK
(
input_
.
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
,
"Input tensor must be BF16 tensor, but scaling mode is "
,
to_string
(
input_
.
scaling_mode
),
"."
);
NVTE_CHECK
(
input_
.
dtype
()
==
transformer_engine
::
DType
::
kBFloat16
,
"Input tensor must be BF16 tensor, but dtype is "
,
to_string
(
input_
.
dtype
()),
"."
);
NVTE_CHECK
(
input_
.
dim
()
>=
2
,
"Input must be a 2D tensor."
);
const
SimpleTensor
&
input
=
input_
.
data
;
SimpleTensor
&
global_amax
=
output_
.
amax
;
SimpleTensor
&
output_t
=
output_
.
data
;
SimpleTensor
&
scale_inv_t
=
output_
.
scale_inv
;
// Stochastic rounding config
const
bool
use_stochastic_rounding
=
quant_config
.
stochastic_rounding
;
const
size_t
*
rng_state
=
nullptr
;
if
(
quant_config
.
rng_state
!=
nullptr
)
{
Tensor
&
rng_state_tensor
=
*
convertNVTETensor
(
quant_config
.
rng_state
);
NVTE_CHECK
(
rng_state_tensor
.
dtype
()
==
DType
::
kInt64
,
"RNG state should contain 2 64-bit values."
);
NVTE_CHECK
(
rng_state_tensor
.
data
.
shape
==
std
::
vector
<
size_t
>
{
2
},
"Shape of the RNG state should be [2], but got "
,
rng_state_tensor
.
data
.
shape
);
rng_state
=
reinterpret_cast
<
const
size_t
*>
(
rng_state_tensor
.
data
.
dptr
);
}
// Template arguments
using
TA
=
cute
::
bfloat16_t
;
using
TB
=
cute
::
bfloat16_t
;
using
TC
=
cutlass
::
float_e2m1_t
;
using
TSFC
=
cutlass
::
float_ue4m3_t
;
checkCuDriverContext
(
stream
);
// Check Hadamard matrix
constexpr
int
kHadamardDimension
=
16
;
NVTE_CHECK
(
hadamard_matrix_
.
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
,
"Hadamard matrix must be BF16 tensor, but scaling mode is "
,
to_string
(
hadamard_matrix_
.
scaling_mode
),
"."
);
NVTE_CHECK
(
hadamard_matrix_
.
dtype
()
==
transformer_engine
::
DType
::
kBFloat16
,
"Hadamard matrix must be BF16 tensor, but dtype is "
,
to_string
(
hadamard_matrix_
.
dtype
()),
"."
);
const
SimpleTensor
&
hadamard_matrix
=
hadamard_matrix_
.
data
;
NVTE_CHECK
(
(
hadamard_matrix_
.
shape
()
==
std
::
vector
<
size_t
>
{
kHadamardDimension
,
kHadamardDimension
}),
"Hadamard matrix must have shape="
,
std
::
vector
<
size_t
>
{
kHadamardDimension
,
kHadamardDimension
},
", but got shape="
,
hadamard_matrix_
.
shape
(),
"."
);
const
size_t
hadamard_dimension
=
hadamard_matrix
.
shape
[
0
];
const
size_t
ndim
=
input
.
shape
.
size
();
const
size_t
n
=
input
.
shape
[
ndim
-
1
];
size_t
m
=
1
;
for
(
size_t
i
=
0
;
i
<
ndim
-
1
;
++
i
)
{
m
*=
input
.
shape
[
i
];
}
auto
sm_count
=
transformer_engine
::
cuda
::
sm_count
();
NVTE_CHECK
(
n
%
hadamard_dimension
==
0
,
"row_length must be divisible by hadamard_dimension."
);
NVTE_CHECK
(
m
%
hadamard_dimension
==
0
,
"num_rows must be divisible by hadamard_dimension"
);
int
k_tile_size
=
1024
;
if
(
m
==
8192
&&
n
==
5120
)
{
k_tile_size
=
512
;
}
else
if
(
m
==
8192
&&
n
==
10240
)
{
k_tile_size
=
1024
;
}
else
if
(
m
==
8192
&&
n
==
2560
)
{
k_tile_size
=
1280
;
}
else
if
(
m
==
8192
&&
n
==
11328
)
{
k_tile_size
=
1024
;
}
else
if
(
m
==
8192
&&
n
==
512
)
{
k_tile_size
=
256
;
}
else
if
(
m
==
8192
&&
n
==
3584
)
{
k_tile_size
=
512
;
}
else
if
(
m
==
11328
&&
n
==
8192
)
{
k_tile_size
=
1024
;
}
else
if
(
m
==
5120
&&
n
==
8192
)
{
k_tile_size
=
512
;
}
else
if
(
m
==
10240
&&
n
==
8192
)
{
k_tile_size
=
1024
;
}
else
if
(
m
==
2560
&&
n
==
8192
)
{
k_tile_size
=
1280
;
}
else
if
(
m
==
512
&&
n
==
8192
)
{
k_tile_size
=
256
;
}
else
if
(
m
==
3584
&&
n
==
8192
)
{
k_tile_size
=
512
;
}
else
if
(
m
<
1024
||
n
<
1024
)
{
k_tile_size
=
512
;
}
TRANSFORMER_ENGINE_SWITCH_CONDITION
(
use_stochastic_rounding
,
kUseStochasticRounding
,
detail
::
rht_gemm_ttt_wrapper
<
TA
,
TB
,
TC
,
TSFC
,
kUseStochasticRounding
>
(
/*m=*/
m
,
/*n=*/
n
,
/*A=*/
reinterpret_cast
<
TA
const
*>
(
input
.
dptr
),
/*B=*/
reinterpret_cast
<
TB
const
*>
(
hadamard_matrix
.
dptr
),
/*C=*/
reinterpret_cast
<
TC
*>
(
output_t
.
dptr
),
/*SFC=*/
reinterpret_cast
<
TSFC
*>
(
scale_inv_t
.
dptr
),
/*global_amax=*/
reinterpret_cast
<
float
const
*>
(
global_amax
.
dptr
),
/*rng_state=*/
rng_state
,
/*sm_count=*/
sm_count
,
/*stream=*/
stream
,
/*k_tile_size=*/
k_tile_size
););
}
}
// namespace transformer_engine
void
nvte_hadamard_transform_cast_fusion_columnwise
(
const
NVTETensor
input
,
NVTETensor
output
,
const
NVTETensor
hadamard_matrix
,
const
NVTEQuantizationConfig
quant_config
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_hadamard_transform_cast_fusion_columnwise
);
using
namespace
transformer_engine
;
QuantizationConfig
quant_config_cpp
;
if
(
quant_config
!=
nullptr
)
{
quant_config_cpp
=
*
reinterpret_cast
<
QuantizationConfig
*>
(
quant_config
);
}
hadamard_transform_cast_fusion_columnwise
(
*
convertNVTETensorCheck
(
input
),
*
convertNVTETensorCheck
(
output
),
*
convertNVTETensorCheck
(
hadamard_matrix
),
quant_config_cpp
,
stream
);
}
transformer_engine/common/include/transformer_engine/activation.h
View file @
063ef88d
...
...
@@ -39,6 +39,7 @@ enum class NVTE_Activation_Type {
QGEGLU
,
SRELU
,
SREGLU
,
CLAMPED_SWIGLU
};
/*! \brief Computes the GeLU activation of the input.
...
...
@@ -173,6 +174,26 @@ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void
nvte_swiglu
(
const
NVTETensor
input
,
NVTETensor
output
,
cudaStream_t
stream
);
/*! \brief Computes the gated Swish activation of the input used in GPT OSS.
*
* See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
* This Gated activation has two differences compared to the original SwiGLU
* 1. Both gate and pre-activations are clipped based on parameter limit.
* 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation inspired
* by original GELU paper https://arxiv.org/pdf/1606.08415
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes Act(input[N, :H]) x input[N, H:]
* \param[in] limit Clipping limits for gate and pre-activation.
* \param[in] alpha Scaling factor for the sigmoid function used in the activation.
* \param[in] stream CUDA stream used for the operation.
*/
void
nvte_clamped_swiglu
(
const
NVTETensor
input
,
NVTETensor
output
,
float
limit
,
float
alpha
,
cudaStream_t
stream
);
/*! \brief Computes the gated ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
...
...
@@ -230,6 +251,26 @@ void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
void
nvte_dswiglu
(
const
NVTETensor
grad
,
const
NVTETensor
input
,
NVTETensor
output
,
cudaStream_t
stream
);
/*! \brief Computes the gradient of gated Swish activation of the input used in GPT OSS.
*
* https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
* This activation has two differences compared to the original SwiGLU
* 1. Both gate and pre-activations are clipped based on parameter limit.
* 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation inspired
* by original GELU paper https://arxiv.org/pdf/1606.08415
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] limit Clipping limits for gate and pre-activation.
* \param[in] alpha Scaling factor for the sigmoid function used in the activation.
* \param[in] stream CUDA stream used for the operation.
*/
void
nvte_clamped_dswiglu
(
const
NVTETensor
grad
,
const
NVTETensor
input
,
NVTETensor
output
,
float
limit
,
float
alpha
,
cudaStream_t
stream
);
/*! \brief Computes the gated ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
...
...
transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h
View file @
063ef88d
...
...
@@ -67,6 +67,11 @@ class CommOverlapCore {
std
::
vector
<
cudaStream_t
>
_stream_compute
;
cudaEvent_t
_start_compute
,
_stop_compute
,
_start_comm
,
_stop_comm
,
_comm_launch_event
;
private:
void
initialize
(
int
tp_size
,
int
num_splits
,
int
num_max_streams
,
int
comm_cga_size
,
int
gemm_priority
,
int
comm_priority
,
int
num_comm_sm
,
bool
set_sm_margin
,
bool
use_ce
,
bool
atomic_gemm
);
public:
CommOverlapCore
()
{}
// dummy constructor for exposing type to Python
...
...
@@ -78,17 +83,26 @@ class CommOverlapCore {
virtual
~
CommOverlapCore
();
void
*
get_ubuf_dptr
()
{
return
_ubuf
.
dptr
();
}
void
set_ubuf_scale_inv
(
float
*
scale_inv
)
{
_ubuf_scale_inv
=
scale_inv
;
_ubuf_scale_inv_initialized
=
true
;
}
virtual
void
copy_into_buffer
(
cudaStream_t
stream
,
const
TensorWrapper
&
source
,
bool
local_chunk
,
bool
rowwise
=
true
)
{
NVTE_ERROR
(
"Operation is not implemented."
);
}
TensorWrapper
get_tensor_chunk
(
const
TensorWrapper
&
source
,
size_t
offset
,
const
std
::
vector
<
size_t
>
&
shape
);
TensorWrapper
get_buffer_chunk_like
(
const
TensorWrapper
&
source
,
size_t
offset
,
const
std
::
vector
<
size_t
>
&
shape
);
int
get_tp_size
()
{
return
_tp_size
;
}
bool
is_atomic_gemm
()
{
return
_atomic_gemm
;
}
bool
is_p2p_overlap
()
{
return
_is_p2p
;
}
...
...
@@ -150,6 +164,10 @@ class CommOverlapBase : public CommOverlapCore {
cudaStream_t
_stream_comm
;
cudaEvent_t
_start_d2dcopy
;
private:
void
initialize
(
const
std
::
vector
<
size_t
>
&
buffer_shape
,
DType
buffer_dtype
,
bool
rs_overlap_first_gemm
);
public:
CommOverlapBase
()
{}
// dummy constructor for exposing type to Python
...
...
@@ -228,6 +246,10 @@ class CommOverlapP2PBase : public CommOverlapCore {
cudaStream_t
_stream_recv
;
cudaEvent_t
_stop_send
,
_stop_recv
;
private:
void
initialize
(
const
std
::
vector
<
size_t
>
&
buffer_shape
,
DType
buffer_dtype
,
CommOverlapType
comm_type
,
bool
aggregate
);
public:
CommOverlapP2PBase
()
{}
// dummy constructor for exposing type to Python
...
...
@@ -241,6 +263,9 @@ class CommOverlapP2PBase : public CommOverlapCore {
virtual
~
CommOverlapP2PBase
();
void
copy_into_buffer
(
cudaStream_t
stream
,
const
TensorWrapper
&
source
,
bool
local_chunk
,
bool
rowwise
=
true
)
override
;
TensorWrapper
get_buffer_chunk_by_id
(
const
TensorWrapper
&
source
,
size_t
buffer_id
);
void
bulk_overlap
(
const
TensorWrapper
&
A
,
bool
transa
,
const
TensorWrapper
&
B
,
bool
transb
,
...
...
transformer_engine/common/include/transformer_engine/fused_attn.h
View file @
063ef88d
...
...
@@ -124,6 +124,24 @@ enum NVTE_Mask_Type {
NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK
=
5
,
};
/*! \enum NVTE_Softmax_Type
* \brief Attention softmax types as described in
* Efficient Streaming Language Models with Attention Sinks (https://arxiv.org/pdf/2309.17453v3).
* For a given attention score S = Q*K^T, different softmax types perform different operations on S,
* NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
* NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
* NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
* where alpha is a learnable parameter in shape [H].
*/
enum
NVTE_Softmax_Type
{
/*! Vanilla softmax */
NVTE_VANILLA_SOFTMAX
=
0
,
/*! Off-by-one softmax */
NVTE_OFF_BY_ONE_SOFTMAX
=
1
,
/*! Learnable softmax */
NVTE_LEARNABLE_SOFTMAX
=
2
,
};
/*! \enum NVTE_Fused_Attn_Backend
* \brief Fused attention backends
*/
...
...
@@ -178,6 +196,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type.
* \param[in] softmax_type The attention softmax type.
* \param[in] dropout The dropout probability.
* \param[in] num_attn_heads The number of heads in Q.
* \param[in] num_gqa_groups The number of heads in K, V.
...
...
@@ -190,9 +209,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
*/
NVTE_Fused_Attn_Backend
nvte_get_fused_attn_backend
(
bool
is_training
,
NVTEDType
q_dtype
,
NVTEDType
kv_dtype
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
float
dropout
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
int64_t
window_size_left
,
int64_t
window_size_right
);
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
float
dropout
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
int64_t
window_size_left
,
int64_t
window_size_right
);
/*! \brief Compute dot product attention with packed QKV input.
*
...
...
@@ -224,6 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
*
* \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
...
...
@@ -239,19 +260,19 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void
nvte_fused_attn_fwd_qkvpacked
(
const
NVTETensor
QKV
,
const
NVTETensor
Bias
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens
,
const
NVTETensor
cu_seqlens_padded
,
const
NVTETensor
rng_state
,
size_t
max_seqlen
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
);
void
nvte_fused_attn_fwd_qkvpacked
(
const
NVTETensor
QKV
,
const
NVTETensor
Bias
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens
,
const
NVTETensor
cu_seqlens_padded
,
const
NVTETensor
rng_state
,
size_t
max_seqlen
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
*
...
...
@@ -284,6 +305,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* e.g. M, ZInv, rng_state.
* \param[out] dQKV The gradient of the QKV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing,
...
...
@@ -293,6 +315,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
...
...
@@ -302,10 +325,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
void
nvte_fused_attn_bwd_qkvpacked
(
const
NVTETensor
QKV
,
const
NVTETensor
O
,
const
NVTETensor
dO
,
const
NVTETensor
S
,
NVTETensor
dP
,
const
NVTETensorPack
*
Aux_CTX_Tensors
,
NVTETensor
dQKV
,
NVTETensor
dBias
,
const
NVTETensor
cu_seqlens
,
const
NVTETensor
cu_seqlens_padded
,
size_t
max_seqlen
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTETensor
dBias
,
NVTETensor
dSoftmaxOffset
,
const
NVTETensor
cu_seqlens
,
const
NVTETensor
cu_seqlens_padded
,
size_t
max_seqlen
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
bool
deterministic
,
NVTETensor
workspace
,
cudaStream_t
stream
);
...
...
@@ -340,6 +364,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in 2HD or H2D layouts.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
...
...
@@ -361,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
...
...
@@ -368,13 +394,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] stream CUDA stream used for this operation.
*/
void
nvte_fused_attn_fwd_kvpacked
(
const
NVTETensor
Q
,
const
NVTETensor
KV
,
const
NVTETensor
Bias
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens_q
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
const
NVTETensor
Q
,
const
NVTETensor
KV
,
const
NVTETensor
Bias
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens_q
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
);
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
);
/*! \brief Compute the backward of the dot product attention with packed KV input.
*
...
...
@@ -409,6 +437,7 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
...
...
@@ -422,6 +451,7 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
...
...
@@ -431,12 +461,12 @@ void nvte_fused_attn_fwd_kvpacked(
void
nvte_fused_attn_bwd_kvpacked
(
const
NVTETensor
Q
,
const
NVTETensor
KV
,
const
NVTETensor
O
,
const
NVTETensor
dO
,
const
NVTETensor
S
,
NVTETensor
dP
,
const
NVTETensorPack
*
Aux_CTX_Tensors
,
NVTETensor
dQ
,
NVTETensor
dKV
,
NVTETensor
dBias
,
const
NVTETensor
cu_seqlens_q
,
const
NVTETensor
cu_seqlens_
kv
,
const
NVTETensor
cu_seqlens_
q_padded
,
const
NVTETensor
cu_seqlens_
kv
_padded
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
bool
deterministic
,
NVTETensor
workspace
,
cudaStream_t
stream
);
NVTETensor
dKV
,
NVTETensor
dBias
,
NVTETensor
dSoftmaxOffset
,
const
NVTETensor
cu_seqlens_
q
,
const
NVTETensor
cu_seqlens_
kv
,
const
NVTETensor
cu_seqlens_
q
_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
bool
deterministic
,
NVTETensor
workspace
,
cudaStream_t
stream
);
/*! \brief Compute dot product attention with separate Q, K and V.
*
...
...
@@ -473,6 +503,7 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] K The K tensor.
* \param[in] V The V tensor.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
...
...
@@ -494,22 +525,24 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void
nvte_fused_attn_fwd
(
const
NVTETensor
Q
,
const
NVTETensor
K
,
const
NVTETensor
V
,
const
NVTETensor
Bias
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens_q
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
Bias
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens_q
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
);
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
*
...
...
@@ -549,6 +582,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
* \param[out] dK The gradient of the K tensor.
* \param[out] dV The gradient of the V tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
...
...
@@ -562,6 +596,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
...
...
@@ -571,14 +606,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
void
nvte_fused_attn_bwd
(
const
NVTETensor
Q
,
const
NVTETensor
K
,
const
NVTETensor
V
,
const
NVTETensor
O
,
const
NVTETensor
dO
,
const
NVTETensor
S
,
NVTETensor
dP
,
const
NVTETensorPack
*
Aux_CTX_Tensors
,
NVTETensor
dQ
,
NVTETensor
dK
,
NVTETensor
dV
,
NVTETensor
dBias
,
const
NVTETensor
cu_seqlens_q
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
NVTETensor
dV
,
NVTETensor
dBias
,
NVTETensor
dSoftmaxOffset
,
const
NVTETensor
cu_seqlens_q
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
bool
deterministic
,
NVTETensor
workspace
,
cudaStream_t
stream
);
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
bool
deterministic
,
NVTETensor
workspace
,
cudaStream_t
stream
);
/*! \brief Update the RNG state with the seed and calculated offset.
*
...
...
transformer_engine/common/include/transformer_engine/gemm.h
View file @
063ef88d
...
...
@@ -15,9 +15,76 @@
#ifdef __cplusplus
extern
"C"
{
#endif
#endif
// __cplusplus
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations.
/*! \brief Configuration for matrix multiplication. */
typedef
void
*
NVTEMatmulConfig
;
/*! \enum NVTEMatmulConfigAttribute
* \brief Type of option for matrix multiplication.
*/
enum
NVTEMatmulConfigAttribute
{
/*! Bias tensor
*
* If provided, the bias tensor is applied in the GEMM epilogue.
*/
kNVTEMatmulConfigBiasTensor
=
0
,
/*! Bias gradient tensor
*
* If provided, the bias gradient tensor will be filled in the GEMM epilogue.
*/
kNVTEMatmulConfigDBiasTensor
=
1
,
/*! Whether to compute GELU in GEMM epilogue. */
kNVTEMatmulConfigWithGELUEpilogue
=
2
,
/*! Whether to compute GELU backward in GEMM epilogue. */
kNVTEMatmulConfigWithDGELUEpilogue
=
3
,
/*! Auxilliary tensor for GEMM epilogue.
*
* For GELU, this will be filled with the GELU input. For GELU
* backward, this is expected to already be filled with the GELU
* input.
*/
kNVTEMatmulConfigEpilogueAuxTensor
=
4
,
/*! Whether to use split accumulator for FP8 GEMM. */
kNVTEMatmulConfigUseSplitAccumulator
=
5
,
/*! Number of streaming multiprocessors to use in GEMM kernel. */
kNVTEMatmulConfigSMCount
=
6
,
kNVTEMatmulConfigNumAttributes
};
/*! \brief Create a matrix multiplication configuration. */
NVTEMatmulConfig
nvte_create_matmul_config
();
/*! \brief Query an option in matrix multiplication configuration.
*
* \param[in] config Matrix multiplication configuration.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value. Ignored if
* NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
*/
void
nvte_get_matmul_config_attribute
(
NVTEMatmulConfig
config
,
NVTEMatmulConfigAttribute
attr
,
void
*
buf
,
size_t
size_in_bytes
,
size_t
*
size_written
);
/*! \brief Set an option in matrix multiplication configuration.
*
* \param[in] config Matrix multiplication configuration.
* \param[in] attr Option type.
* \param[out] buf Memory address to read option value.
* \param[in] size_in_bytes Size of buf.
*/
void
nvte_set_matmul_config_attribute
(
NVTEMatmulConfig
config
,
NVTEMatmulConfigAttribute
attr
,
const
void
*
buf
,
size_t
size_in_bytes
);
/*! \brief Destroy a matrix multiplication configuration. */
void
nvte_destroy_matmul_config
(
NVTEMatmulConfig
config
);
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated).
*
* This has been deprecated in favor of nvte_cublas_gemm_v2.
*
* Computes:
* - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
...
...
@@ -44,8 +111,31 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
=
0
,
bool
nvte_use_rocblas
=
0
,
int
compute_stream_offset
=
0
);
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations.
*
* Computes:
* - `D = alpha * op(A) * op(B) + beta * C`
*
* \param[in] transa Whether to transpose A matrix.
* \param[in] transb Whether to transpose B matrix.
* \param[in] alpha Scaling factor applied to matmul output.
* \param[in] A A matrix.
* \param[in] B B matrix.
* \param[in] beta Scaling factor applied to C matrix.
* \param[in] C C matrix.
* \param[out] D Output matrix.
* \param[in] workspace Workspace tensor.
* \param[in] config Additional configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void
nvte_cublas_gemm_v2
(
int
transa
,
int
transb
,
const
float
*
alpha
,
const
NVTETensor
A
,
const
NVTETensor
B
,
const
float
*
beta
,
const
NVTETensor
C
,
NVTETensor
D
,
NVTETensor
workspace
,
NVTEMatmulConfig
config
,
cudaStream_t
stream
,
bool
nvte_use_hipblaslt
,
bool
nvte_use_rocblas
,
int
compute_stream_offset
);
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations,
* allowing for using a scaling factor for the GEMM result and the accumulation input
* allowing for using a scaling factor for the GEMM result and the accumulation input (deprecated)
*
* This has been deprecated in favor of nvte_cublas_gemm_v2.
*
* Computes:
* - `D = alpha*AB` if both `bias` and `pre_gelu_out` are empty tensors
...
...
@@ -133,9 +223,9 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
* \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics)
* \param[in] stream CUDA stream to wait on.
*/
void
nvte_multi_tensor_gemm
(
const
NVTETensor
*
A
,
const
NVTETensor
*
B
,
NVTETensor
*
D
,
const
NVTETensor
*
bias
,
NVTETensor
*
pre_gelu_out
,
const
int
num_gemms
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
*
workspace
,
void
nvte_multi_tensor_gemm
(
const
NVTETensor
*
A
,
const
NVTETensor
*
B
,
NVTETensor
*
D
,
const
NVTETensor
*
bias
,
NVTETensor
*
pre_gelu_out
,
const
int
num_gemms
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
*
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
cudaStream_t
stream
);
...
...
@@ -160,7 +250,9 @@ void nvte_cublas_batchgemm_tensorwise_int8(const NVTETensor A, const NVTETensor
#ifdef __cplusplus
}
// extern "C"
#endif
#endif // __cplusplus
#ifdef __cplusplus
/*! \namespace transformer_engine
*/
...
...
@@ -178,6 +270,89 @@ constexpr int num_batchgemm_streams = 1;
void
nvte_cublas_handle_init
();
/*! \struct MatmulConfigWrapper
* \brief C++ wrapper for NVTEMatmulConfig.
*/
class
MatmulConfigWrapper
{
public:
MatmulConfigWrapper
()
:
config_
{
nvte_create_matmul_config
()}
{}
MatmulConfigWrapper
(
const
MatmulConfigWrapper
&
)
=
delete
;
MatmulConfigWrapper
&
operator
=
(
const
MatmulConfigWrapper
&
)
=
delete
;
MatmulConfigWrapper
(
MatmulConfigWrapper
&&
other
)
:
config_
{
other
.
config_
}
{
other
.
config_
=
nullptr
;
}
MatmulConfigWrapper
&
operator
=
(
MatmulConfigWrapper
&&
other
)
{
if
(
config_
!=
nullptr
)
{
nvte_destroy_matmul_config
(
config_
);
}
config_
=
other
.
config_
;
other
.
config_
=
nullptr
;
return
*
this
;
}
~
MatmulConfigWrapper
()
{
if
(
config_
!=
nullptr
)
{
nvte_destroy_matmul_config
(
config_
);
config_
=
nullptr
;
}
}
/*! \brief Get the underlying NVTEMatmulConfig.
*
* \return NVTEMatmulConfig held by this MatmulConfigWrapper.
*/
operator
NVTEMatmulConfig
()
const
noexcept
{
return
config_
;
}
/*! \brief Set bias tensor. */
void
set_bias_tensor
(
NVTETensor
bias_tensor
)
{
nvte_set_matmul_config_attribute
(
config_
,
kNVTEMatmulConfigBiasTensor
,
&
bias_tensor
,
sizeof
(
NVTETensor
));
}
/*! \brief Set bias gradient tensor. */
void
set_dbias_tensor
(
NVTETensor
dbias_tensor
)
{
nvte_set_matmul_config_attribute
(
config_
,
kNVTEMatmulConfigDBiasTensor
,
&
dbias_tensor
,
sizeof
(
NVTETensor
));
}
/*! \brief Set whether to compute GELU in GEMM epilogue. */
void
set_with_gelu_epilogue
(
bool
with_gelu_epilogue
)
{
nvte_set_matmul_config_attribute
(
config_
,
kNVTEMatmulConfigWithGELUEpilogue
,
&
with_gelu_epilogue
,
sizeof
(
bool
));
}
/*! \brief Set whether to compute GELU backward in GEMM epilogue. */
void
set_with_dgelu_epilogue
(
bool
with_dgelu_epilogue
)
{
nvte_set_matmul_config_attribute
(
config_
,
kNVTEMatmulConfigWithDGELUEpilogue
,
&
with_dgelu_epilogue
,
sizeof
(
bool
));
}
/*! \brief Set auxilliary tensor for GEMM epilogue. */
void
set_epilogue_aux_tensor
(
NVTETensor
epilogue_aux_tensor
)
{
nvte_set_matmul_config_attribute
(
config_
,
kNVTEMatmulConfigEpilogueAuxTensor
,
&
epilogue_aux_tensor
,
sizeof
(
NVTETensor
));
}
/*! \brief Set whether to use split accumulator for FP8 GEMM. */
void
set_use_split_accumulator
(
bool
use_split_accumulator
)
{
nvte_set_matmul_config_attribute
(
config_
,
kNVTEMatmulConfigUseSplitAccumulator
,
&
use_split_accumulator
,
sizeof
(
bool
));
}
/*! \brief Set number of streaming multiprocessors to use in GEMM kernel. */
void
set_sm_count
(
int
sm_count
)
{
nvte_set_matmul_config_attribute
(
config_
,
kNVTEMatmulConfigSMCount
,
&
sm_count
,
sizeof
(
int
));
}
private:
/*! \brief Wrapped NVTEMatmulConfig. */
NVTEMatmulConfig
config_
=
nullptr
;
};
}
// namespace transformer_engine
#endif // __cplusplus
#endif // TRANSFORMER_ENGINE_GEMM_H_
transformer_engine/common/include/transformer_engine/hadamard_transform.h
0 → 100644
View file @
063ef88d
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file hadamard_transform.h
* \brief Functions for Hadamard transforms.
*/
#ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
#define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern
"C"
{
#endif
/*! \brief Perform a randomized Hadamard transform on the input tensor.
*
* This function is experimental and the API is not stable.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] output Output tensor.
* \param[in] random_sign_mask 16-bit sign mask.
* \param[in] random_sign_mask_t 16-bit sign mask.
* \param[in] stream CUDA stream used for the operation.
*/
void
nvte_hadamard_transform
(
const
NVTETensor
input
,
NVTETensor
output
,
int
random_sign_mask
,
int
random_sign_mask_t
,
cudaStream_t
stream
);
/*! \brief Perform the absolute maximum reduction on the input tensor with/without
* randomized hadamard transform. The rowwise result is the absolute maximum
* of the input tensor. The columnwise result is the absolute maximum of the
* input tensor transposed and applied randomized hadamard transformation.
*
* This function is experimental and the API is not stable.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] output Output tensor.
* \param[in] random_sign_mask 16-bit sign mask.
* \param[in] random_sign_mask_t 16-bit sign mask.
* \param[in] stream CUDA stream used for the operation.
*/
void
nvte_hadamard_transform_amax
(
const
NVTETensor
input
,
NVTETensor
output
,
int
random_sign_mask
,
int
random_sign_mask_t
,
cudaStream_t
stream
);
/*! \brief Perform the columnwise hadamard transform cast fusion.
*
* This function is experimental and the API is not stable.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] output Output tensor.
* \param[in] hadamard_matrix Hadamard matrix.
* \param[in] quant_config Quantization configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void
nvte_hadamard_transform_cast_fusion_columnwise
(
const
NVTETensor
input
,
NVTETensor
output
,
const
NVTETensor
hadamard_matrix
,
const
NVTEQuantizationConfig
quant_config
,
cudaStream_t
stream
);
#ifdef __cplusplus
}
// extern "C"
#endif
#endif // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
transformer_engine/common/include/transformer_engine/recipe.h
View file @
063ef88d
...
...
@@ -124,6 +124,10 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
size_t
start_offset
,
size_t
block_len
,
const
NVTEDType
out_dtype
,
cudaStream_t
stream
);
void
nvte_nvfp4_compute_per_tensor_scale
(
const
NVTETensor
inpA
,
const
bool
use_rowwise_amax_A
,
const
NVTETensor
inpB
,
const
bool
use_rowwise_amax_B
,
float
alpha_in
,
NVTETensor
alpha_out
,
cudaStream_t
stream
);
#ifdef __cplusplus
}
// extern "C"
#endif
...
...
transformer_engine/common/include/transformer_engine/swizzle.h
View file @
063ef88d
...
...
@@ -44,6 +44,26 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud
void
nvte_multi_tensor_swizzle_scaling_factors
(
const
NVTETensor
*
inputs
,
NVTETensor
*
outputs
,
const
size_t
num_tensors
,
cudaStream_t
stream
);
/*! \brief Swizzling FP8 block scaling scaling factors into mxfp8 interleaved layout for GEMM
*
* \param[in] input Input FP8 block scaling tensor with GEMM_READY scale_inv.
* \param[in,out] output Output mxfp8 tensor which hosts swizzled scale_inv.
* \param[in] stream CUDA stream used for the operation.
*
* This function is used for emulating the FP8 block scaling recipe on Blackwell and newer as it
* not natively supported by cublasLt on architectures other than Hopper.
* Requirements:
* - input is an FP8 block scaling tensor
* - input has rowwise usage
* - input.scale_inv is in GEMM_READY format
* - output is an MXFP8 tensor
* - output has rowwise usage
* - output.scale_inv has appropriate shape
* */
void
nvte_swizzle_block_scaling_to_mxfp8_scaling_factors
(
const
NVTETensor
input
,
NVTETensor
output
,
cudaStream_t
stream
);
#ifdef __cplusplus
}
// extern "C"
#endif
...
...
transformer_engine/common/include/transformer_engine/transformer_engine.h
View file @
063ef88d
...
...
@@ -73,6 +73,7 @@ enum NVTETensorParam {
kNVTEAmax
=
3
,
/*!< Amax tensor */
kNVTERowwiseScaleInv
=
4
,
/*!< Scale inverse tensor for decoding Rowwise Data */
kNVTEColumnwiseScaleInv
=
5
,
/*!< Scale inverse tensor for decoding Columnwise Data */
kNVTEColumnwiseAmax
=
6
,
/*!< Columnwise Amax tensor */
kNVTENumTensorParams
};
...
...
@@ -95,10 +96,9 @@ enum NVTEScalingMode {
*/
NVTE_BLOCK_SCALING_1D
=
2
,
NVTE_BLOCK_SCALING_2D
=
3
,
/*! Single NVFP4 scale per block of 16 contiguous elements in forward pass (FWD),
and single MXFP8 scale per block of 32 contiguous elements in backward pass (BWD).
*/
NVTE_FWD_NVFP4_BWD_MXFP8_SCALING
=
4
,
/*! Single scale per block of 16 elements consecutive in either
* rowwise or columnwise direction */
NVTE_NVFP4_1D_SCALING
=
4
,
NVTE_INVALID_SCALING
=
100
};
...
...
@@ -337,6 +337,12 @@ enum NVTEQuantizationConfigAttribute {
* likely be refactored away in the future.
*/
kNVTEQuantizationConfigFloat8BlockScaleTensorFormat
=
3
,
/*! RNG state (NVTETensor with 2 elements - seed and offset */
kNVTEQuantizationConfigRNGState
=
4
,
/*! Whether to use 2D block scaling for NVFP4 */
kNVTEQuantizationConfigNVFP42DQuantization
=
5
,
/*! Whether to enable stochastic rounding */
kNVTEQuantizationConfigStochasticRounding
=
6
,
kNVTEQuantizationConfigNumAttributes
};
...
...
@@ -458,6 +464,15 @@ inline bool is_fp4_dtype(const DType t) {
#endif
}
/*! \brief Check if TE datatype is high precision (FP32, FP16, BF16)
*
* Return true if TE datatype is high precision
* \param[in] DType TE Datatype of interest
*/
inline
bool
is_high_precision_dtype
(
const
DType
t
)
{
return
t
==
DType
::
kFloat32
||
t
==
DType
::
kBFloat16
||
t
==
DType
::
kFloat16
;
}
/*! \struct TensorWrapper
* \brief C++ wrapper for the NVTETensor class.
*/
...
...
@@ -593,6 +608,11 @@ class TensorWrapper {
return
set_parameter
(
kNVTEColumnwiseScaleInv
,
dptr
,
type
,
shape
);
}
template
<
typename
ShapeType
>
TensorWrapper
&
set_columnwise_amax
(
void
*
dptr
,
DType
type
,
const
ShapeType
&
shape
)
noexcept
{
return
set_parameter
(
kNVTEColumnwiseAmax
,
dptr
,
type
,
shape
);
}
// Parameter getters
NVTEBasicTensor
get_parameter
(
const
NVTETensorParam
param
)
const
noexcept
{
...
...
@@ -617,6 +637,10 @@ class TensorWrapper {
return
get_parameter
(
kNVTEColumnwiseScaleInv
);
}
NVTEBasicTensor
get_columnwise_amax
()
const
noexcept
{
return
get_parameter
(
kNVTEColumnwiseAmax
);
}
/*! \brief Get an underlying NVTETensor.
*
* \return NVTETensor held by this TensorWrapper.
...
...
@@ -865,6 +889,24 @@ class QuantizationConfigWrapper {
&
format
,
sizeof
(
Float8BlockScaleTensorFormat
));
}
/*! \brief Set stochastic rounding state */
void
set_rng_state
(
NVTETensor
rng_state
)
{
nvte_set_quantization_config_attribute
(
config_
,
kNVTEQuantizationConfigRNGState
,
&
rng_state
,
sizeof
(
NVTETensor
));
}
/*! \brief Set whether to use 2D block scaling for NVFP4 */
void
set_nvfp4_2d_quantization
(
bool
nvfp4_2d_quantization
)
{
nvte_set_quantization_config_attribute
(
config_
,
kNVTEQuantizationConfigNVFP42DQuantization
,
&
nvfp4_2d_quantization
,
sizeof
(
bool
));
}
/*! \brief Set whether to use stochastic rounding */
void
set_stochastic_rounding
(
bool
stochastic_rounding
)
{
nvte_set_quantization_config_attribute
(
config_
,
kNVTEQuantizationConfigStochasticRounding
,
&
stochastic_rounding
,
sizeof
(
bool
));
}
private:
/*! \brief Wrapped NVTEQuantizationConfig. */
NVTEQuantizationConfig
config_
=
nullptr
;
...
...
transformer_engine/common/normalization/layernorm/ln_api.cpp
View file @
063ef88d
...
...
@@ -28,7 +28,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const
int
multiprocessorCount
,
const
bool
zero_centered_gamma
,
cudaStream_t
stream
)
{
if
(
is_fp8_dtype
(
z
->
data
.
dtype
)
&&
!
is_delayed_tensor_scaling
(
z
->
scaling_mode
)
&&
!
is_mxfp_scaling
(
z
->
scaling_mode
))
{
!
is_mxfp
8
_scaling
(
z
->
scaling_mode
))
{
NVTE_ERROR
(
"Not implemented scaling mode: "
+
to_string
(
z
->
scaling_mode
)
+
"."
);
}
...
...
@@ -65,11 +65,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
bool
is_aligned
=
true
;
#ifdef USE_ROCM
NVTE_CHECK
(
!
is_mxfp_scaling
(
z
->
scaling_mode
),
!
is_mxfp
8
_scaling
(
z
->
scaling_mode
),
"Cudnn backend is need by block scaling mode for normalization! Not surpported in rocm yet."
);
bool
cudnn_backend
=
use_cudnn_norm_fwd
()
||
is_mxfp_scaling
(
z
->
scaling_mode
);
bool
cudnn_backend
=
use_cudnn_norm_fwd
()
||
is_mxfp
8
_scaling
(
z
->
scaling_mode
);
#else
bool
cudnn_backend
=
use_cudnn_norm_fwd
()
||
is_mxfp_scaling
(
z
->
scaling_mode
);
bool
cudnn_backend
=
use_cudnn_norm_fwd
()
||
is_mxfp
8
_scaling
(
z
->
scaling_mode
);
#endif
if
(
!
is_fp8_dtype
(
z
->
data
.
dtype
)
&&
z
->
amax
.
dptr
!=
nullptr
)
{
...
...
transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
View file @
063ef88d
...
...
@@ -24,7 +24,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
Tensor
*
rsigma
,
Tensor
*
workspace
,
const
int
multiprocessorCount
,
const
bool
zero_centered_gamma
,
cudaStream_t
stream
)
{
if
(
is_fp8_dtype
(
z
->
data
.
dtype
)
&&
!
is_delayed_tensor_scaling
(
z
->
scaling_mode
)
&&
!
is_mxfp_scaling
(
z
->
scaling_mode
))
{
!
is_mxfp
8
_scaling
(
z
->
scaling_mode
))
{
NVTE_ERROR
(
"Not implemented scaling mode: "
+
to_string
(
z
->
scaling_mode
)
+
"."
);
}
...
...
@@ -51,11 +51,11 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
bool
is_aligned
=
true
;
#ifdef USE_ROCM
NVTE_CHECK
(
!
is_mxfp_scaling
(
z
->
scaling_mode
),
!
is_mxfp
8
_scaling
(
z
->
scaling_mode
),
"Cudnn backend is need by mxfp scaling mode for normalization! Not surpported in rocm yet."
);
bool
cudnn_backend
=
use_cudnn_norm_fwd
()
||
is_mxfp_scaling
(
z
->
scaling_mode
);
bool
cudnn_backend
=
use_cudnn_norm_fwd
()
||
is_mxfp
8
_scaling
(
z
->
scaling_mode
);
#else
bool
cudnn_backend
=
use_cudnn_norm_fwd
()
||
is_mxfp_scaling
(
z
->
scaling_mode
);
bool
cudnn_backend
=
use_cudnn_norm_fwd
()
||
is_mxfp
8
_scaling
(
z
->
scaling_mode
);
#endif
if
(
!
is_fp8_dtype
(
z
->
data
.
dtype
)
&&
z
->
amax
.
dptr
!=
nullptr
)
{
...
...
transformer_engine/common/recipe/__init__.py
View file @
063ef88d
...
...
@@ -4,10 +4,10 @@
"""This module provides predefined FP8 recipes."""
from
__future__
import
annotations
import
warnings
import
os
from
enum
import
Enum
from
typing
import
Literal
,
Optional
,
Union
,
Callable
,
NamedTuple
from
typing
import
Any
,
Literal
,
Optional
,
Union
,
Callable
,
NamedTuple
from
dataclasses
import
field
from
pydantic.dataclasses
import
dataclass
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
...
...
@@ -23,9 +23,12 @@ class _FormatHelper(NamedTuple):
class
Format
(
Enum
):
"""
Supported FP8 formats.
Supported FP4 formats.
Values
------
E2M1 :
All FP4 tensors are in e2m1 format
E4M3 :
All FP8 tensors are in e4m3 format
E5M2 :
...
...
@@ -35,6 +38,7 @@ class Format(Enum):
FP8 tensors in the backward pass are in e5m2 format
"""
E2M1
=
_FormatHelper
(
max_fwd
=
6
,
max_bwd
=
6
)
E4M3
=
_FormatHelper
(
max_fwd
=
448
,
max_bwd
=
448
)
E5M2
=
_FormatHelper
(
max_fwd
=
57344
,
max_bwd
=
57344
)
HYBRID
=
_FormatHelper
(
max_fwd
=
E4M3
.
max_fwd
,
max_bwd
=
E5M2
.
max_bwd
)
...
...
@@ -42,9 +46,13 @@ class Format(Enum):
@
dataclass
(
frozen
=
True
)
class
MMParams
:
"""for pytorch as an example, _scaled_mm use_fast_accum = (not use_split_accumulator)
apply split accumulator or not, turning it on will increase accuracy but impact gemm performance,
so only turn it on for certain gemms
"""Matrix multiplication options.
Parameters
----------
use_split_accumulator : bool, default = `True`
Use FP8 fast accumulation on Hopper or Ada. For more details,
see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul.
"""
use_split_accumulator
:
bool
=
True
...
...
@@ -55,10 +63,24 @@ class QParams:
"""Quantization parameters.
power_2_scale: use power of 2 scale parameter
amax_epsilon: optional minimum value of abs max
random_hadamard_transform: whether to use random hadamard transform
stochastic_rounding: whether to use stocastic rounding
"""
power_2_scale
:
bool
=
False
amax_epsilon
:
float
=
0.0
random_hadamard_transform
:
bool
=
False
stochastic_rounding
:
bool
=
False
fp4_2d_quantization
:
bool
=
False
def
__repr__
(
self
)
->
str
:
return
(
f
"Qparams(
\n
power_2_scale=
{
self
.
power_2_scale
}
,
\n
"
f
"amax_epsilon=
{
self
.
amax_epsilon
}
,
\n
"
f
"random_hadamard_transform=
{
self
.
random_hadamard_transform
}
,
\n
"
f
"stochastic_rounding=
{
self
.
stochastic_rounding
}
,
\n
"
f
"fp4_2d_quantization=
{
self
.
fp4_2d_quantization
}
\n
)"
)
class
Recipe
:
...
...
@@ -66,6 +88,10 @@ class Recipe:
Base recipe class.
"""
def
nvfp4
(
self
):
"""Whether the given recipe is NVFP4 1D block scaling."""
return
isinstance
(
self
,
NVFP4BlockScaling
)
def
mxfp8
(
self
):
"""Whether the given recipe is MXFP8 block scaling."""
return
isinstance
(
self
,
MXFP8BlockScaling
)
...
...
@@ -86,6 +112,10 @@ class Recipe:
"""Whether the given recipe is float8 blockwise scaling."""
return
isinstance
(
self
,
Float8BlockScaling
)
def
custom
(
self
):
"""Whether the given recipe is custom."""
return
isinstance
(
self
,
CustomRecipe
)
@
dataclass
()
class
DelayedScaling
(
Recipe
):
...
...
@@ -131,7 +161,7 @@ class DelayedScaling(Recipe):
where `Tensor` is a framework tensor type.
reduce_amax: bool, default = `True`
By default, if `torch.distributed` is initialized, the `amax` value for FP8
tensors is reduced across the `
fp8
_group` (specified in the `
fp8_
autocast`
tensors is reduced across the `
amax_reduction
_group` (specified in the `autocast`
call). This keeps the amaxes and scaling factors synced across the given
distributed group. If set to `False`, this reduction is skipped and every
GPU maintains local amaxes and scaling factors. To ensure results are
...
...
@@ -139,7 +169,7 @@ class DelayedScaling(Recipe):
ranks must checkpoint in order to store the local tensors.
fp8_dpa: bool, default = `False`
Whether to enable FP8 dot product attention (DPA). When the model is placed in an
`
fp8_
autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
`autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
inputs from higher precision to FP8, performs attention in FP8, and casts tensors
back to higher precision as outputs. FP8 DPA currently is only supported in the
`FusedAttention` backend.
...
...
@@ -184,6 +214,7 @@ class DelayedScaling(Recipe):
f
"margin=
{
self
.
margin
}
, "
f
"format=
{
str
(
self
.
fp8_format
).
split
(
'.'
)[
1
]
}
, "
f
"amax_history_len=
{
self
.
amax_history_len
}
, "
f
"reduce_amax=
{
self
.
reduce_amax
}
, "
f
"fp8_dpa=
{
self
.
fp8_dpa
}
, "
f
"fp8_mha=
{
self
.
fp8_mha
}
"
)
...
...
@@ -201,10 +232,11 @@ class Float8CurrentScaling(Recipe):
pass.
"""
use_power_2_scales
:
bool
=
os
.
getenv
(
"NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES"
,
"0"
)
==
"1"
fp8_format
:
Format
=
Format
.
HYBRID
fp8_quant_fwd_inp
=
QParams
(
power_2_scale
=
False
,
amax_epsilon
=
0.0
)
fp8_quant_fwd_weight
=
QParams
(
power_2_scale
=
False
,
amax_epsilon
=
0.0
)
fp8_quant_bwd_grad
=
QParams
(
power_2_scale
=
False
,
amax_epsilon
=
0.0
)
fp8_quant_fwd_inp
=
QParams
(
power_2_scale
=
use_power_2_scales
,
amax_epsilon
=
0.0
)
fp8_quant_fwd_weight
=
QParams
(
power_2_scale
=
use_power_2_scales
,
amax_epsilon
=
0.0
)
fp8_quant_bwd_grad
=
QParams
(
power_2_scale
=
use_power_2_scales
,
amax_epsilon
=
0.0
)
fp8_gemm_fprop
:
MMParams
=
MMParams
(
use_split_accumulator
=
False
)
fp8_gemm_dgrad
:
MMParams
=
MMParams
(
use_split_accumulator
=
True
)
fp8_gemm_wgrad
:
MMParams
=
MMParams
(
use_split_accumulator
=
True
)
...
...
@@ -213,9 +245,6 @@ class Float8CurrentScaling(Recipe):
def
__post_init__
(
self
)
->
None
:
assert
self
.
fp8_format
!=
Format
.
E5M2
,
"Pure E5M2 training is not supported."
assert
(
not
self
.
fp8_dpa
and
not
self
.
fp8_mha
),
"FP8 attention is not supported for Float8CurrentScaling."
def
__repr__
(
self
)
->
str
:
return
(
...
...
@@ -334,6 +363,7 @@ class Float8BlockScaling(Recipe):
assert
(
not
self
.
fp8_dpa
and
not
self
.
fp8_mha
),
"FP8 attention is not supported for Float8BlockScaling."
assert
self
.
fp8_format
!=
Format
.
E5M2
,
"Pure E5M2 training is not supported."
def
__repr__
(
self
)
->
str
:
return
(
...
...
@@ -351,3 +381,134 @@ class Float8BlockScaling(Recipe):
f
"fp8_dpa=
{
self
.
fp8_dpa
}
, "
f
"fp8_mha=
{
self
.
fp8_mha
}
"
)
@
dataclass
()
class
NVFP4BlockScaling
(
Recipe
):
"""
Use the NVFP4 scaling strategy.
This is a 2-level block scaling strategy. In level 1, each group of
16 consecutive values is scaled together using their own scaling
factor. The type of the scaling factor is E4M3 (4 bits of exponent,
3 bits of mantissa). In level 2, a global per tensor FP32 scaling
factor is used to scale the entire tensor.
Since the scaling happens in a particular direction (either rowwise
or columnwise), in this recipe the quantized tensor and its transpose
are not numerically equivalent. Due to this, when Transformer Engine
needs both the tensor and its transpose (e.g. to calculate both
forward and backward pass), during the quantization both versions are
computed from the high precision input to avoid double quantization
errors.
The default NVFP4 training recipe implements 3 techniques for quantizing
to a narrow format (4-bit):
- For weight tensors a variant of the NVFP4 quantization is used,
where a single scaling factor is shared by a 2D block of 16x16 elements.
- When quantizing gradients, stochastic rounding is applied to avoid the bias
introduced by quantization. With this, values are rounded probabilistically
to one of their two nearest representable numbers, with probabilities
inversely proportional to their distances.
- When quantizing inputs and gradients, random Hadamard transforms are applied
(16x16 Hadamard matrix) to smooth outliers in the tensor distributions
and make them easier to represent accurately in NVFP4.
These techniques are described more comprehensively in the NVFP4 paper titled
'Pretraining Large Language Models with NVFP4' (https://arxiv.org/abs/2509.25149v1).
Parameters
----------
fp4_format : {Format.E2M1}, default = Format.E2M1
FP4 data type.
disable_rht : bool, default = `False`
If set to `True`, random Hadamard transforms are not applied to any tensor.
disable_stochastic_rounding : bool, default = `False`
If set to `True`, stochastic rounding is disabled during quantization for all tensors.
disable_2d_quantization : bool, default = `False`
If set to `True`, 1D block scaling with block size 16 is used for all tensors.
"""
# Configuration envvars
disable_rht
:
bool
=
os
.
getenv
(
"NVTE_NVFP4_DISABLE_RHT"
,
"0"
)
==
"1"
disable_stochastic_rounding
:
bool
=
(
os
.
getenv
(
"NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING"
,
"0"
)
==
"1"
)
disable_2d_quantization
:
bool
=
os
.
getenv
(
"NVTE_NVFP4_DISABLE_2D_QUANTIZATION"
,
"0"
)
==
"1"
fp4_format
:
Format
=
Format
.
E2M1
fp8_format
:
Format
=
Format
.
E4M3
# Not applying quantization to attention for now
fp8_dpa
:
bool
=
False
fp8_mha
:
bool
=
False
def
__post_init__
(
self
)
->
None
:
assert
self
.
fp4_format
==
Format
.
E2M1
,
"Only E2M1 is supported for NVFP4 scaling"
assert
self
.
fp8_format
==
Format
.
E4M3
,
"Only E4M3 is supported for NVFP4 scaling"
# Quantization params
# Note: RHT is currently only applied to column-wise usage so that
# it can be used for wgrad GEMM.
self
.
fp4_quant_fwd_inp
=
QParams
(
random_hadamard_transform
=
not
self
.
disable_rht
,
stochastic_rounding
=
False
,
fp4_2d_quantization
=
False
,
)
self
.
fp4_quant_fwd_weight
=
QParams
(
random_hadamard_transform
=
False
,
stochastic_rounding
=
False
,
fp4_2d_quantization
=
not
self
.
disable_2d_quantization
,
)
self
.
fp4_quant_bwd_grad
=
QParams
(
random_hadamard_transform
=
not
self
.
disable_rht
,
stochastic_rounding
=
not
self
.
disable_stochastic_rounding
,
fp4_2d_quantization
=
False
,
)
def
__repr__
(
self
)
->
str
:
return
(
f
"recipe_type=
{
self
.
__class__
.
__name__
}
, "
f
"fp4_format=
{
str
(
self
.
fp4_format
).
split
(
'.'
)[
1
]
}
, "
f
"fp8_format=
{
str
(
self
.
fp8_format
).
split
(
'.'
)[
1
]
}
, "
f
"fp8_dpa=
{
self
.
fp8_dpa
}
, "
f
"fp8_mha=
{
self
.
fp8_mha
}
, "
f
"fp4_quant_fwd_inp=
{
self
.
fp4_quant_fwd_inp
}
, "
f
"fp4_quant_fwd_weight=
{
self
.
fp4_quant_fwd_weight
}
, "
f
"fp4_quant_bwd_grad=
{
self
.
fp4_quant_bwd_grad
}
, "
)
@
dataclass
()
class
CustomRecipe
(
Recipe
):
"""
Custom recipe that allows users to provide quantizer factories.
.. warning::
**EXPERIMENTAL**: Custom recipe is experimental, still under active development,
and the API is subject to change without notice. Use at your own risk.
Parameters
----------
qfactory : Callable
Factory callable that returns a quantizer instance for a
given semantic tensor role.
The callable is typically invoked as:
qfactory(
role: str,
)
Where `role` is one of the following strings for e.g. te.Linear
(stable public contract):
- forward: "linear_input", "linear_weight", "linear_output"
- backward: "linear_grad_output", "linear_grad_input"
"""
qfactory
:
Callable
[...,
Any
]
fp8_dpa
:
bool
=
False
fp8_mha
:
bool
=
False
def
__repr__
(
self
)
->
str
:
return
f
"recipe_type=
{
self
.
__class__
.
__name__
}
, qfactory=
{
self
.
qfactory
}
"
transformer_engine/common/recipe/current_scaling.cu
View file @
063ef88d
...
...
@@ -27,6 +27,13 @@ namespace {
constexpr
int
amax_kernel_threads
=
512
;
__launch_bounds__
(
1
)
__global__
void
zero_amax_kernel
(
float
*
amax_ptr
,
const
float
*
noop_ptr
)
{
if
(
noop_ptr
!=
nullptr
&&
noop_ptr
[
0
]
==
1.0
f
)
{
return
;
}
*
amax_ptr
=
0
;
}
template
<
int
nvec
,
bool
aligned
,
typename
InputType
>
__launch_bounds__
(
amax_kernel_threads
)
__global__
void
amax_kernel
(
const
InputType
*
input
,
float
*
amax
,
const
size_t
N
,
...
...
@@ -131,7 +138,8 @@ template <int nvec, typename InputType>
void
launch_amax_kernel
(
const
InputType
*
input
,
float
*
amax
,
const
size_t
N
,
const
float
*
noop_ptr
,
cudaStream_t
stream
)
{
// Zero out amax so we can update with atomic max
NVTE_CHECK_CUDA
(
cudaMemsetAsync
(
amax
,
0
,
sizeof
(
float
),
stream
));
zero_amax_kernel
<<<
1
,
1
,
0
,
stream
>>>
(
amax
,
noop_ptr
);
NVTE_CHECK_CUDA
(
cudaGetLastError
());
// Return immediately if tensor is empty
if
(
N
==
0
)
{
...
...
@@ -216,15 +224,17 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt
// Check output tensor
NVTE_CHECK
(
output_
!=
nullptr
,
"Invalid output tensor (got NULL)"
);
auto
&
output
=
*
convertNVTETensorCheck
(
output_
);
NVTE_CHECK
(
output
.
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
,
"Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
NVTE_CHECK
(
output
.
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
||
output
.
scaling_mode
==
NVTE_NVFP4_1D_SCALING
,
"Output tensor for amax computation must be FP8 tensor with per-tensor scaling or "
"NVFP4 1D scaling, "
"but got scaling_mode="
,
to_string
(
output
.
scaling_mode
));
NVTE_CHECK
(
output
.
amax
.
numel
()
==
1
,
"Output tensor for amax computation has invalid amax tensor "
"(expected 1 entry, got shape="
,
output
.
amax
.
shape
,
")"
);
NVTE_CHECK
(
output
.
amax
.
dptr
!=
nullptr
,
NVTE_CHECK
(
output
.
amax
.
dptr
!=
nullptr
||
output
.
columnwise_amax
.
dptr
!=
nullptr
,
"Output tensor for amax computation has amax tensor without data"
);
NVTE_CHECK
(
output
.
amax
.
dtype
==
DType
::
kFloat32
,
"Output tensor for amax computation has invalid amax tensor "
...
...
@@ -243,11 +253,12 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt
}
// Compute amax
float
*
amax_ptr
=
reinterpret_cast
<
float
*>
(
(
output
.
amax
.
dptr
!=
nullptr
)
?
output
.
amax
.
dptr
:
output
.
columnwise_amax
.
dptr
);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT
(
input
.
data
.
dtype
,
IType
,
constexpr
int
nvec
=
32
/
sizeof
(
IType
);
launch_amax_kernel
<
nvec
>
(
reinterpret_cast
<
const
IType
*>
(
input
.
data
.
dptr
),
reinterpret_cast
<
float
*>
(
output
.
amax
.
dptr
),
input
.
data
.
numel
(),
noop_ptr
,
stream
););
// NOLINT(*)
input
.
data
.
dtype
,
IType
,
constexpr
int
nvec
=
32
/
sizeof
(
IType
);
launch_amax_kernel
<
nvec
>
(
reinterpret_cast
<
const
IType
*>
(
input
.
data
.
dptr
),
amax_ptr
,
input
.
data
.
numel
(),
noop_ptr
,
stream
););
// NOLINT(*)
}
}
// anonymous namespace
...
...
transformer_engine/common/recipe/nvfp4.cu
0 → 100644
View file @
063ef88d
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/recipe.h>
#include <cassert>
#include "../common.h"
#include "../utils.cuh"
namespace
transformer_engine
{
namespace
nvfp4_recipe
{
// constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0;
constexpr
float
factor_inv
=
1.0
/
(
6.0
*
6.0
*
448.0
*
448.0
);
// Kernel to compute alpha *= amax_A * amax_B / factor
__global__
void
compute_nvfp4_per_tensor_scale_kernel
(
float
alpha_in
,
const
float
*
amax_A
,
const
float
*
amax_B
,
float
*
alpha_out
)
{
// factor is defined in the enclosing namespace
*
alpha_out
=
alpha_in
*
(
*
amax_A
)
*
(
*
amax_B
)
*
factor_inv
;
}
}
// namespace nvfp4_recipe
}
// namespace transformer_engine
void
nvte_nvfp4_compute_per_tensor_scale
(
const
NVTETensor
inpA
,
const
bool
use_rowwise_amax_A
,
const
NVTETensor
inpB
,
const
bool
use_rowwise_amax_B
,
float
alpha_in
,
NVTETensor
alpha_out
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_nvfp4_compute_per_tensor_scale
);
using
namespace
transformer_engine
;
auto
*
tA
=
convertNVTETensor
(
inpA
);
auto
*
tB
=
convertNVTETensor
(
inpB
);
auto
*
tOut
=
convertNVTETensor
(
alpha_out
);
void
*
amax_A_ptr
=
use_rowwise_amax_A
?
tA
->
amax
.
dptr
:
tA
->
columnwise_amax
.
dptr
;
void
*
amax_B_ptr
=
use_rowwise_amax_B
?
tB
->
amax
.
dptr
:
tB
->
columnwise_amax
.
dptr
;
void
*
alpha_ptr
=
tOut
->
data
.
dptr
;
// check for not null pointers
NVTE_CHECK
(
amax_A_ptr
!=
nullptr
,
"amax_A_ptr is null"
);
NVTE_CHECK
(
amax_B_ptr
!=
nullptr
,
"amax_B_ptr is null"
);
NVTE_CHECK
(
alpha_ptr
!=
nullptr
,
"alpha_ptr is null"
);
nvfp4_recipe
::
compute_nvfp4_per_tensor_scale_kernel
<<<
1
,
1
,
0
,
stream
>>>
(
alpha_in
,
reinterpret_cast
<
const
float
*>
(
amax_A_ptr
),
reinterpret_cast
<
const
float
*>
(
amax_B_ptr
),
reinterpret_cast
<
float
*>
(
alpha_ptr
));
NVTE_CHECK_CUDA
(
cudaGetLastError
());
}
transformer_engine/common/swizzle/swizzle.cu
View file @
063ef88d
...
...
@@ -18,7 +18,9 @@
namespace
transformer_engine
{
namespace
{
constexpr
__device__
__host__
int
MXFP8_BLOCK_SIZE
=
32
;
constexpr
int
MXFP8_BLOCK_SIZE
=
32
;
constexpr
int
NVFP4_BLOCK_SIZE
=
16
;
constexpr
__device__
__host__
int
TB_DIM
=
32
;
constexpr
__device__
__host__
int
NEW_SF_TILE_DIM_K
=
16
;
constexpr
__device__
__host__
int
N_SF_PER_TD_PER_TILE
=
4
;
...
...
@@ -314,8 +316,6 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
const
int
original_K
=
kernel_args
.
original_k_list
[
tensor_id
];
constexpr
int
N_TILE_PER_TD
=
sizeof
(
LType
)
/
sizeof
(
int
);
constexpr
int
N_SF_PER_TD
=
N_TILE_PER_TD
*
N_SF_PER_TD_PER_TILE
;
constexpr
int
SF_TILE_SIZE_I32
=
SF_TILE_DIM_M
*
SF_TILE_DIM_K
/
4
;
// Get block index in grid. Emulate 2D grid.
const
int
num_tiles_k
=
K
/
SF_TILE_DIM_K
;
...
...
@@ -332,9 +332,13 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
}
// namespace
void
swizzle_scaling_factors
(
const
Tensor
*
input
,
Tensor
*
output
,
cudaStream_t
stream
)
{
if
(
!
is_fp8_dtype
(
input
->
dtype
())
||
is_delayed_tensor_scaling
(
input
->
scaling_mode
))
{
NVTE_ERROR
(
"Not implemented caling mode "
+
to_string
(
input
->
scaling_mode
)
+
"."
);
}
NVTE_CHECK
(
input
->
scaling_mode
==
NVTE_MXFP8_1D_SCALING
||
input
->
scaling_mode
==
NVTE_BLOCK_SCALING_1D
||
input
->
scaling_mode
==
NVTE_BLOCK_SCALING_2D
||
input
->
scaling_mode
==
NVTE_NVFP4_1D_SCALING
,
"Input tensor has invalid scaling mode ("
,
to_string
(
input
->
scaling_mode
),
")."
);
NVTE_CHECK
(
is_fp8_dtype
(
input
->
dtype
())
||
is_fp4_dtype
(
input
->
dtype
()),
"Input tensor has invalid dtype ("
,
to_string
(
input
->
dtype
()),
")."
);
// Do nothing if tensor is empty
if
(
input
->
data
.
numel
()
==
0
)
{
...
...
@@ -345,176 +349,202 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
CheckInputTensor
(
*
output
,
"scaling_factor_output"
);
auto
&
scaling_mode
=
input
->
scaling_mode
;
NVTE_CHECK
(
scaling_mode
==
NVTE_MXFP8_1D_SCALING
||
scaling_mode
==
NVTE_NVFP4_1D_SCALING
,
"Unsupported scaling mode for swizzling."
);
bool
nvfp4
=
scaling_mode
==
NVTE_NVFP4_1D_SCALING
;
// 1D block scaling, row-wise or colum-wise
if
(
scaling_mode
==
NVTE_MXFP8_1D_SCALING
)
{
const
int
m
=
input
->
has_data
()
?
input
->
scale_inv
.
shape
[
0
]
:
input
->
columnwise_scale_inv
.
shape
[
1
];
const
int
k
=
input
->
has_data
()
?
input
->
scale_inv
.
shape
[
1
]
:
input
->
columnwise_scale_inv
.
shape
[
0
];
constexpr
int
SF_TILE_DIM_M
=
128
;
constexpr
int
SF_TILE_DIM_K
=
4
;
NVTE_CHECK
(
m
%
SF_TILE_DIM_M
==
0
,
"Input should be padded in M/N dimension!"
);
NVTE_CHECK
(
k
%
SF_TILE_DIM_K
==
0
,
"Input should be padded in K dimension!"
);
NVTE_CHECK
(
k
>
0
,
"Input scale inverse should be 2D!"
);
if
(
output
->
has_data
())
{
NVTE_CHECK
(
m
*
k
==
std
::
accumulate
(
output
->
scale_inv
.
shape
.
begin
(),
output
->
scale_inv
.
shape
.
end
(),
1
,
std
::
multiplies
<
int
>
()),
"Input.scale_inv size is not equal to Output.scale_inv size!"
);
}
if
(
output
->
has_columnwise_data
())
{
NVTE_CHECK
(
m
*
k
==
std
::
accumulate
(
output
->
columnwise_scale_inv
.
shape
.
begin
(),
output
->
columnwise_scale_inv
.
shape
.
end
(),
1
,
std
::
multiplies
<
int
>
()),
"Input.columnwise_scale_inv size is not equal to "
"Output.columnwise_scale_inv size!"
);
int
m
,
k
;
if
(
input
->
has_data
())
{
m
=
input
->
scale_inv
.
shape
[
0
];
k
=
input
->
scale_inv
.
shape
[
1
];
}
else
{
if
(
nvfp4
)
{
m
=
input
->
columnwise_scale_inv
.
shape
[
0
];
k
=
input
->
columnwise_scale_inv
.
shape
[
1
];
}
else
{
m
=
input
->
columnwise_scale_inv
.
shape
[
1
];
k
=
input
->
columnwise_scale_inv
.
shape
[
0
];
}
}
int
num_tiles_m
=
m
/
SF_TILE_DIM_M
;
int
num_tiles_k
=
k
/
SF_TILE_DIM_K
;
constexpr
int
SF_TILE_DIM_M
=
128
;
constexpr
int
SF_TILE_DIM_K
=
4
;
NVTE_CHECK
(
m
%
SF_TILE_DIM_M
==
0
,
"Input should be padded in M/N dimension!"
);
NVTE_CHECK
(
k
%
SF_TILE_DIM_K
==
0
,
"Input should be padded in K dimension!"
);
NVTE_CHECK
(
k
>
0
,
"Input scale inverse should be 2D!"
);
if
(
output
->
has_data
())
{
NVTE_CHECK
(
m
*
k
==
std
::
accumulate
(
output
->
scale_inv
.
shape
.
begin
(),
output
->
scale_inv
.
shape
.
end
(),
1
,
std
::
multiplies
<
int
>
()),
"Input.scale_inv size is not equal to Output.scale_inv size!"
);
}
if
(
output
->
has_columnwise_data
())
{
NVTE_CHECK
(
m
*
k
==
std
::
accumulate
(
output
->
columnwise_scale_inv
.
shape
.
begin
(),
output
->
columnwise_scale_inv
.
shape
.
end
(),
1
,
std
::
multiplies
<
int
>
()),
"Input.columnwise_scale_inv size is not equal to "
"Output.columnwise_scale_inv size!"
);
}
int
num_tiles_m
=
m
/
SF_TILE_DIM_M
;
int
num_tiles_k
=
k
/
SF_TILE_DIM_K
;
// For NVFP4, the scale inverse for tranposed data needs rowwise swizzle.
const
bool
rowwise_swizzle
=
input
->
has_data
()
||
nvfp4
;
const
bool
columnwise_swizzle
=
input
->
has_columnwise_data
()
&&
!
nvfp4
;
dim3
block_size
(
TB_DIM
,
TB_DIM
);
if
(
input
->
has_data
())
{
int
vec_load_size
=
(
num_tiles_k
-
1
)
%
4
+
1
;
/* there is no int3 and misaligned if using int4/int2 */
if
(
vec_load_size
==
3
)
vec_load_size
=
1
;
int
n_tiles_in_tb
=
TB_DIM
*
vec_load_size
;
dim3
num_blocks
(
DIVUP
(
num_tiles_k
,
n_tiles_in_tb
),
num_tiles_m
);
int
slm_size
=
n_tiles_in_tb
*
SF_TILE_DIM_M
*
SF_TILE_DIM_K
*
sizeof
(
int8_t
);
const
int
original_M
=
input
->
flat_first_dim
();
const
int
original_K
=
input
->
flat_last_dim
()
/
MXFP8_BLOCK_SIZE
;
switch
(
vec_load_size
)
{
if
(
rowwise_swizzle
)
{
int
vec_load_size
=
(
num_tiles_k
-
1
)
%
4
+
1
;
/* there is no int3 and misaligned if using int4/int2 */
if
(
vec_load_size
==
3
)
vec_load_size
=
1
;
int
n_tiles_in_tb
=
TB_DIM
*
vec_load_size
;
dim3
num_blocks
(
DIVUP
(
num_tiles_k
,
n_tiles_in_tb
),
num_tiles_m
);
int
slm_size
=
n_tiles_in_tb
*
SF_TILE_DIM_M
*
SF_TILE_DIM_K
*
sizeof
(
int8_t
);
int
original_M
,
original_K
;
void
*
input_scale_inv_ptr
,
*
output_scale_inv_ptr
;
if
(
!
nvfp4
||
input
->
has_data
())
{
int
block_scale_size
=
nvfp4
?
NVFP4_BLOCK_SIZE
:
MXFP8_BLOCK_SIZE
;
original_M
=
input
->
flat_first_dim
();
original_K
=
input
->
flat_last_dim
()
/
block_scale_size
;
input_scale_inv_ptr
=
input
->
scale_inv
.
dptr
;
output_scale_inv_ptr
=
output
->
scale_inv
.
dptr
;
}
else
{
original_M
=
input
->
flat_last_dim
();
original_K
=
input
->
flat_first_dim
()
/
NVFP4_BLOCK_SIZE
;
input_scale_inv_ptr
=
input
->
columnwise_scale_inv
.
dptr
;
output_scale_inv_ptr
=
output
->
columnwise_scale_inv
.
dptr
;
}
switch
(
vec_load_size
)
{
#ifdef __HIP_PLATFORM_AMD__
case
4
:
cudaFuncSetAttribute
((
const
void
*
)
swizzle_row_scaling_kernel
<
int4
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
);
swizzle_row_scaling_kernel
<
int4
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
scale_inv
.
d
ptr
,
output
->
scale_inv
.
d
ptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
2
:
cudaFuncSetAttribute
((
const
void
*
)
swizzle_row_scaling_kernel
<
int2
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
);
swizzle_row_scaling_kernel
<
int2
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
scale_inv
.
d
ptr
,
output
->
scale_inv
.
d
ptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
1
:
cudaFuncSetAttribute
((
const
void
*
)
swizzle_row_scaling_kernel
<
int
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
);
swizzle_row_scaling_kernel
<
int
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
scale_inv
.
d
ptr
,
output
->
scale_inv
.
d
ptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
4
:
cudaFuncSetAttribute
((
const
void
*
)
swizzle_row_scaling_kernel
<
int4
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
);
swizzle_row_scaling_kernel
<
int4
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
_
scale_inv
_
ptr
,
output
_
scale_inv
_
ptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
2
:
cudaFuncSetAttribute
((
const
void
*
)
swizzle_row_scaling_kernel
<
int2
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
);
swizzle_row_scaling_kernel
<
int2
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
_
scale_inv
_
ptr
,
output
_
scale_inv
_
ptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
1
:
cudaFuncSetAttribute
((
const
void
*
)
swizzle_row_scaling_kernel
<
int
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
);
swizzle_row_scaling_kernel
<
int
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
_
scale_inv
_
ptr
,
output
_
scale_inv
_
ptr
,
m
,
k
,
original_M
,
original_K
);
break
;
#else
case
4
:
NVTE_CHECK_CUDA
(
cudaFuncSetAttribute
(
swizzle_row_scaling_kernel
<
int4
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
));
swizzle_row_scaling_kernel
<
int4
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
scale_inv
.
d
ptr
,
output
->
scale_inv
.
d
ptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
2
:
NVTE_CHECK_CUDA
(
cudaFuncSetAttribute
(
swizzle_row_scaling_kernel
<
int2
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
));
swizzle_row_scaling_kernel
<
int2
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
scale_inv
.
d
ptr
,
output
->
scale_inv
.
d
ptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
1
:
NVTE_CHECK_CUDA
(
cudaFuncSetAttribute
(
swizzle_row_scaling_kernel
<
int
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
));
swizzle_row_scaling_kernel
<
int
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
scale_inv
.
d
ptr
,
output
->
scale_inv
.
d
ptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
4
:
NVTE_CHECK_CUDA
(
cudaFuncSetAttribute
(
swizzle_row_scaling_kernel
<
int4
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
));
swizzle_row_scaling_kernel
<
int4
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
_
scale_inv
_
ptr
,
output
_
scale_inv
_
ptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
2
:
NVTE_CHECK_CUDA
(
cudaFuncSetAttribute
(
swizzle_row_scaling_kernel
<
int2
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
));
swizzle_row_scaling_kernel
<
int2
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
_
scale_inv
_
ptr
,
output
_
scale_inv
_
ptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
1
:
NVTE_CHECK_CUDA
(
cudaFuncSetAttribute
(
swizzle_row_scaling_kernel
<
int
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
));
swizzle_row_scaling_kernel
<
int
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
_
scale_inv
_
ptr
,
output
_
scale_inv
_
ptr
,
m
,
k
,
original_M
,
original_K
);
break
;
#endif
default:
NVTE_ERROR
(
"Not valid vec_load_size."
);
break
;
}
NVTE_CHECK_CUDA
(
cudaGetLastError
());
default:
NVTE_ERROR
(
"Not valid vec_load_size."
);
break
;
}
if
(
input
->
has_columnwise_data
())
{
int
vec_load_size
=
(
num_tiles_m
-
1
)
%
4
+
1
;
if
(
vec_load_size
==
3
)
vec_load_size
=
1
;
/* no int3 and misaligned if using int4/int2 */
int
n_tiles_in_tb
=
TB_DIM
*
vec_load_size
;
dim3
num_blocks
(
DIVUP
(
num_tiles_k
,
TB_DIM
),
DIVUP
(
num_tiles_m
,
vec_load_size
));
int
slm_size
=
n_tiles_in_tb
*
SF_TILE_DIM_M
*
SF_TILE_DIM_K
*
sizeof
(
int8_t
);
const
int
original_M
=
input
->
flat_last_dim
();
const
int
original_K
=
input
->
flat_first_dim
()
/
MXFP8_BLOCK_SIZE
;
switch
(
vec_load_size
)
{
}
if
(
columnwise_swizzle
)
{
int
vec_load_size
=
(
num_tiles_m
-
1
)
%
4
+
1
;
if
(
vec_load_size
==
3
)
vec_load_size
=
1
;
/* no int3 and misaligned if using int4/int2 */
int
n_tiles_in_tb
=
TB_DIM
*
vec_load_size
;
dim3
num_blocks
(
DIVUP
(
num_tiles_k
,
TB_DIM
),
DIVUP
(
num_tiles_m
,
vec_load_size
));
int
slm_size
=
n_tiles_in_tb
*
SF_TILE_DIM_M
*
SF_TILE_DIM_K
*
sizeof
(
int8_t
);
const
int
original_M
=
input
->
flat_last_dim
();
const
int
original_K
=
input
->
flat_first_dim
()
/
MXFP8_BLOCK_SIZE
;
// NVFP4 shouldn't end up here because it only needs rowwise swizzle
NVTE_CHECK
(
!
nvfp4
,
"NVFP4 shouldn't end up here because it only needs rowwise swizzle"
);
switch
(
vec_load_size
)
{
#ifdef __HIP_PLATFORM_AMD__
case
4
:
cudaFuncSetAttribute
((
const
void
*
)
swizzle_col_scaling_kernel
<
int4
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
);
swizzle_col_scaling_kernel
<
int4
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
columnwise_scale_inv
.
dptr
,
output
->
columnwise_scale_inv
.
dptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
2
:
cudaFuncSetAttribute
((
const
void
*
)
swizzle_col_scaling_kernel
<
int2
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
);
swizzle_col_scaling_kernel
<
int2
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
columnwise_scale_inv
.
dptr
,
output
->
columnwise_scale_inv
.
dptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
1
:
cudaFuncSetAttribute
((
const
void
*
)
swizzle_col_scaling_kernel
<
int
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
);
swizzle_col_scaling_kernel
<
int
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
columnwise_scale_inv
.
dptr
,
output
->
columnwise_scale_inv
.
dptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
4
:
cudaFuncSetAttribute
((
const
void
*
)
swizzle_col_scaling_kernel
<
int4
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
);
swizzle_col_scaling_kernel
<
int4
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
columnwise_scale_inv
.
dptr
,
output
->
columnwise_scale_inv
.
dptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
2
:
cudaFuncSetAttribute
((
const
void
*
)
swizzle_col_scaling_kernel
<
int2
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
);
swizzle_col_scaling_kernel
<
int2
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
columnwise_scale_inv
.
dptr
,
output
->
columnwise_scale_inv
.
dptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
1
:
cudaFuncSetAttribute
((
const
void
*
)
swizzle_col_scaling_kernel
<
int
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
);
swizzle_col_scaling_kernel
<
int
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
columnwise_scale_inv
.
dptr
,
output
->
columnwise_scale_inv
.
dptr
,
m
,
k
,
original_M
,
original_K
);
break
;
#else
case
4
:
NVTE_CHECK_CUDA
(
cudaFuncSetAttribute
(
swizzle_col_scaling_kernel
<
int4
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
));
swizzle_col_scaling_kernel
<
int4
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
columnwise_scale_inv
.
dptr
,
output
->
columnwise_scale_inv
.
dptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
2
:
NVTE_CHECK_CUDA
(
cudaFuncSetAttribute
(
swizzle_col_scaling_kernel
<
int2
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
));
swizzle_col_scaling_kernel
<
int2
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
columnwise_scale_inv
.
dptr
,
output
->
columnwise_scale_inv
.
dptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
1
:
NVTE_CHECK_CUDA
(
cudaFuncSetAttribute
(
swizzle_col_scaling_kernel
<
int
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
));
swizzle_col_scaling_kernel
<
int
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
columnwise_scale_inv
.
dptr
,
output
->
columnwise_scale_inv
.
dptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
4
:
NVTE_CHECK_CUDA
(
cudaFuncSetAttribute
(
swizzle_col_scaling_kernel
<
int4
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
));
swizzle_col_scaling_kernel
<
int4
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
columnwise_scale_inv
.
dptr
,
output
->
columnwise_scale_inv
.
dptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
2
:
NVTE_CHECK_CUDA
(
cudaFuncSetAttribute
(
swizzle_col_scaling_kernel
<
int2
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
));
swizzle_col_scaling_kernel
<
int2
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
columnwise_scale_inv
.
dptr
,
output
->
columnwise_scale_inv
.
dptr
,
m
,
k
,
original_M
,
original_K
);
break
;
case
1
:
NVTE_CHECK_CUDA
(
cudaFuncSetAttribute
(
swizzle_col_scaling_kernel
<
int
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
slm_size
));
swizzle_col_scaling_kernel
<
int
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
<<<
num_blocks
,
block_size
,
slm_size
,
stream
>>>
(
input
->
columnwise_scale_inv
.
dptr
,
output
->
columnwise_scale_inv
.
dptr
,
m
,
k
,
original_M
,
original_K
);
break
;
#endif
default:
NVTE_ERROR
(
"Not valid vec_load_size."
);
break
;
}
NVTE_CHECK_CUDA
(
cudaGetLastError
());
default:
NVTE_ERROR
(
"Not valid vec_load_size."
);
break
;
}
// 2D block scaling
}
else
{
NVTE_ERROR
(
"Not implemented for scaling_mode "
+
to_string
(
input
->
scaling_mode
)
+
", trans."
);
}
NVTE_CHECK_CUDA
(
cudaGetLastError
());
}
...
...
@@ -650,6 +680,8 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
}
NVTE_CHECK_CUDA
(
cudaGetLastError
());
}
// TODO(nvfp4): Add NVFP4 support.
void
multi_tensor_swizzle_scaling_factors
(
const
std
::
vector
<
Tensor
*>&
input
,
std
::
vector
<
Tensor
*>&
output
,
cudaStream_t
stream
)
{
auto
num_tensors
=
input
.
size
();
...
...
@@ -776,7 +808,7 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
* WIP (Phuong):
* - Opt for bank conflicts
* - Adding swizzle for 2d-block scaling.
*/
*/
void
nvte_swizzle_scaling_factors
(
const
NVTETensor
input
,
NVTETensor
output
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_swizzle_scaling_factors
);
using
namespace
transformer_engine
;
...
...
transformer_engine/common/swizzle/swizzle_block_scaling.cu
0 → 100644
View file @
063ef88d
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/swizzle.h>
#include <cstdint>
#include <type_traits>
#include "../common.h"
#include "../util/logging.h"
#include "transformer_engine/transformer_engine.h"
namespace
transformer_engine
{
namespace
{
constexpr
uint32_t
WARP_SIZE
=
32
;
}
// namespace
namespace
swizzle_kernel_1d
{
constexpr
uint32_t
WARPS_X_PER_TB
=
2
;
// configurable
constexpr
uint32_t
WARPS_Y_PER_TB
=
2
;
// configurable
// Transposes a 4x4 matrix of bytes stored across four threads with consecutive thread ids where
// each thread stores a single row (of four bytes).
// Example:
// lane0.row = 0x00010203
// lane1.row = 0x04050607
// lane2.row = 0x08090a0b
// lane3.row = 0x0c0d0e0f
// Becomes:
// lane0.row = 0x0004080c
// lane1.row = 0x0105090d
// lane2.row = 0x02060a0e
// lane3.row = 0x03070b0f
uint32_t
__device__
__forceinline__
transpose_4x4_byte_matrix
(
const
uint32_t
row
,
const
uint32_t
lane
,
const
uint32_t
active_mask
)
{
using
cu
=
const
uint32_t
;
// Threads operate in groups of 4, and each thread stores 4 bytes at a time.
// The bytes in this 4x4 matrix are labeled in hex. We shuffle around bytes
// until we have transposed the 4x4 matrix.
cu
m_0123_4567_89ab_cdef
=
row
;
cu
m_4567_0123_cdef_89ab
=
__shfl_xor_sync
(
active_mask
,
m_0123_4567_89ab_cdef
,
1
,
4
);
cu
m_0426_4062_8cae_c8ea
=
__byte_perm
(
m_0123_4567_89ab_cdef
,
m_4567_0123_cdef_89ab
,
0x6240
);
cu
m_5173_1537_d9fb_9dbf
=
__byte_perm
(
m_0123_4567_89ab_cdef
,
m_4567_0123_cdef_89ab
,
0x3715
);
cu
m_0426_1537_8cae_9dbf
=
(
lane
&
1
)
?
m_5173_1537_d9fb_9dbf
:
m_0426_4062_8cae_c8ea
;
cu
m_8cae_9dbf_0426_1537
=
__shfl_xor_sync
(
active_mask
,
m_0426_1537_8cae_9dbf
,
2
,
4
);
cu
m_048c_159d_8c04_9d15
=
__byte_perm
(
m_0426_1537_8cae_9dbf
,
m_8cae_9dbf_0426_1537
,
0x5410
);
cu
m_ae26_bf37_26ae_37bf
=
__byte_perm
(
m_0426_1537_8cae_9dbf
,
m_8cae_9dbf_0426_1537
,
0x3276
);
cu
m_048c_159d_26ae_37bf
=
(
lane
&
2
)
?
m_ae26_bf37_26ae_37bf
:
m_048c_159d_8c04_9d15
;
return
m_048c_159d_26ae_37bf
;
}
// Expands a uint32_t to a uint4 by duplicating each byte four times.
// Example: 0x01020304u becomes uint4{0x01010101, 0x02020202, 0x03030303, 0x04040404}
uint4
__device__
__forceinline__
broadcast_uint32_t_to_uint4
(
uint32_t
x
)
{
return
{
__byte_perm
(
x
,
0
,
0x0000
),
__byte_perm
(
x
,
0
,
0x1111
),
__byte_perm
(
x
,
0
,
0x2222
),
__byte_perm
(
x
,
0
,
0x3333
)};
}
// Tag struct denoting whether the number of rows of the input fp8 block scaling tensor's data
// matrix is divisible by 128. If it is not, some threads could read out of bounds scaling factors.
struct
no_oob_tag_t
{};
constexpr
no_oob_tag_t
NO_OOB_TAG
;
template
<
typename
OOBT
>
void
__global__
__launch_bounds__
(
WARPS_X_PER_TB
*
WARPS_Y_PER_TB
*
WARP_SIZE
)
swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel
(
const
void
*
__restrict__
const
in
,
void
*
__restrict__
const
out
,
const
uint32_t
tiles_x
,
const
uint32_t
tiles_y
,
const
uint32_t
in_y_stride
,
const
uint32_t
out_y_stride
,
OOBT
first_oob
)
{
// resolve kernel variant
constexpr
bool
no_oob
=
std
::
is_same_v
<
OOBT
,
no_oob_tag_t
>
;
static_assert
(
no_oob
||
std
::
is_same_v
<
OOBT
,
uint32_t
>
);
// load thread indices
const
uint32_t
lane
=
threadIdx
.
x
;
__builtin_assume
(
lane
<
WARP_SIZE
);
const
uint32_t
warp_x
=
threadIdx
.
z
;
__builtin_assume
(
warp_x
<
WARPS_X_PER_TB
);
const
uint32_t
warp_y
=
threadIdx
.
y
;
__builtin_assume
(
warp_y
<
WARPS_Y_PER_TB
);
// compute tile indices
const
uint32_t
out_tile_y
=
blockIdx
.
y
*
WARPS_Y_PER_TB
+
warp_y
;
const
uint32_t
out_tile_x
=
blockIdx
.
x
*
WARPS_X_PER_TB
+
warp_x
;
const
uint32_t
in_tile_y
=
out_tile_x
;
const
uint32_t
in_tile_x
=
out_tile_y
;
// bounds check; uniform branch
if
(
out_tile_y
>=
tiles_y
||
out_tile_x
>=
tiles_x
)
{
return
;
}
// calculate this warp's input base pointer
constexpr
uint32_t
in_x_stride
=
WARP_SIZE
*
sizeof
(
uint4
);
const
void
*
const
warp_src
=
in
+
in_tile_y
*
in_y_stride
+
in_tile_x
*
in_x_stride
;
// load scaling factors for this lane's initial four 1x128 tiles
uint4
sf
;
if
constexpr
(
no_oob
)
{
sf
=
reinterpret_cast
<
const
uint4
*>
(
warp_src
)[
lane
];
}
else
{
if
((
out_tile_y
<
tiles_y
-
1
)
||
lane
<
first_oob
)
{
sf
=
reinterpret_cast
<
const
uint4
*>
(
warp_src
)[
lane
];
}
else
{
sf
=
uint4
{
0
,
0
,
0
,
0
};
}
}
// pack the exponent bits of the scaling factors
uint32_t
packed_exponents
=
(
sf
.
x
>>
23
)
|
(
sf
.
y
>>
15
)
|
(
sf
.
z
>>
7
)
|
(
sf
.
w
<<
1
);
// partially swizzle the scaling factors
constexpr
uint32_t
ACTIVE_MASK
=
0xFFFFFFFF
;
// no divergent branches
const
uint32_t
lane_load_idx
=
(
lane
%
4
)
*
8
+
(
lane
/
4
);
packed_exponents
=
__shfl_sync
(
ACTIVE_MASK
,
packed_exponents
,
lane_load_idx
);
// transpose 4x4 matrices of scaling factors
packed_exponents
=
transpose_4x4_byte_matrix
(
packed_exponents
,
lane
%
4
,
ACTIVE_MASK
);
// broadcast the scaling factors for sixteen 1x32 tiles
sf
=
broadcast_uint32_t_to_uint4
(
packed_exponents
);
// store them cooperatively for 512 1x32 tiles in a 128x128 tile
constexpr
uint32_t
out_x_stride
=
512
;
void
*
const
warp_dst
=
out
+
out_tile_y
*
out_y_stride
+
out_tile_x
*
out_x_stride
;
reinterpret_cast
<
uint4
*>
(
warp_dst
)[
lane
]
=
sf
;
}
// Launch the 1D-block-scaling -> MXFP8 scaling-factor swizzle over the whole
// scale matrix. `in`/`out` point at the scaling-factor buffers; `data_rows` /
// `data_cols` describe the associated FP8 data matrix.
void launch_kernel(const void *const in, void *const out, uint32_t data_rows, uint32_t data_cols,
                   cudaStream_t stream) {
  // The kernel moves scaling factors with vectorized uint4 loads/stores, so
  // both buffers must satisfy uint4 alignment.
  NVTE_CHECK(is_aligned_ptr(in, alignof(uint4)), "Input scaling factor pointer must be aligned to ",
             alignof(uint4), " bytes");
  NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)),
             "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes");
  NVTE_CHECK(data_rows % 4 == 0, "Input tensor must not have any padding scaling factors");

  // One warp per 128x128 data tile.
  const uint32_t tiles_x = DIVUP(data_cols, 128u);
  const uint32_t tiles_y = DIVUP(data_rows, 128u);
  const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1};
  const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB};

  // Each 128x128 tile in the data corresponds to a 128x1 tile in the input scales
  // and a 128x4 tile in the output scales. The input scales are in transposed order.
  const uint32_t in_sf_cols = DIVUP(data_rows, 4u) * 4;
  const uint32_t out_sf_cols = tiles_x * 128 * 4;
  const uint32_t in_y_stride = in_sf_cols * sizeof(float);
  const uint32_t out_y_stride = out_sf_cols * sizeof(uint8_t);

  // First lane in the last tile row whose load would fall past the input
  // extent. Zero means the extent is an exact tile multiple, which is
  // signalled to the kernel with the NO_OOB_TAG sentinel so it can skip the
  // bounds handling entirely.
  const uint32_t first_oob = (in_sf_cols % 128) / 4;
  const uint32_t oob_tag = (first_oob == 0) ? NO_OOB_TAG : first_oob;
  swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<<grid_dim, block_dim, 0, stream>>>(
      in, out, tiles_x, tiles_y, in_y_stride, out_y_stride, oob_tag);
}
}
// namespace swizzle_kernel_1d
namespace swizzle_kernel_2d {

constexpr uint32_t WARPS_X_PER_TB = 2;  // configurable
constexpr uint32_t WARPS_Y_PER_TB = 2;  // configurable

// Expand 2D block-scaling factors (one FP32 scale per 128x128 data tile) into
// the swizzled layout used for MXFP8 scaling factors: each data tile receives
// a 128x4-byte block of identical exponent bytes. One warp handles one tile.
void __global__ __launch_bounds__(WARPS_X_PER_TB * WARPS_Y_PER_TB * WARP_SIZE)
    swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel(
        const void *__restrict__ const in, void *__restrict__ const out, const uint32_t tiles_x,
        const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride) {
  // Thread coordinates: lane within the warp, warp coordinates within the block.
  const uint32_t lane = threadIdx.x;
  __builtin_assume(lane < WARP_SIZE);
  const uint32_t warp_x = threadIdx.z;
  __builtin_assume(warp_x < WARPS_X_PER_TB);
  const uint32_t warp_y = threadIdx.y;
  __builtin_assume(warp_y < WARPS_Y_PER_TB);

  // Tile handled by this warp; input and output tile coordinates coincide.
  const uint32_t tile_y = blockIdx.y * WARPS_Y_PER_TB + warp_y;
  const uint32_t tile_x = blockIdx.x * WARPS_X_PER_TB + warp_x;

  // Bounds check; uniform across the warp, so no divergent branches.
  if (tile_y >= tiles_y || tile_x >= tiles_x) {
    return;
  }

  // Load the single FP32 scaling factor for this warp's 128x128 tile.
  constexpr uint32_t in_x_stride = sizeof(float);
  const char *const src =
      static_cast<const char *>(in) + tile_y * in_y_stride + tile_x * in_x_stride;
  uint32_t bits = *reinterpret_cast<const uint32_t *>(src);

  // Replicate the FP32 exponent field (bits 30:23) into all four bytes:
  // (bits << 1) places it in byte 3, (bits >> 7) in byte 2, and the final OR
  // copies those two bytes down into bytes 1 and 0.
  // NOTE(review): the low bytes come out clean only when the sign and mantissa
  // bits are zero, i.e. the scale is a non-negative power of two -- this
  // precondition is assumed to be guaranteed upstream; confirm.
  bits = (bits << 1) | (bits >> 7);
  bits = bits | (bits >> 16);

  // Sixteen identical scale bytes per thread; 32 lanes cover the whole
  // 128x4 = 512-byte output tile (one byte per 1x32 data sub-tile).
  const uint4 packed{bits, bits, bits, bits};
  constexpr uint32_t out_x_stride = 512;
  char *const dst = static_cast<char *>(out) + tile_y * out_y_stride + tile_x * out_x_stride;
  reinterpret_cast<uint4 *>(dst)[lane] = packed;
}

// Launch the 2D-block-scaling -> MXFP8 scaling-factor swizzle over the whole
// scale matrix. `data_rows` / `data_cols` describe the FP8 data matrix.
void launch_kernel(const void *const in, void *const out, uint32_t data_rows, uint32_t data_cols,
                   cudaStream_t stream) {
  // The kernel reads scales as 32-bit words and writes them as uint4 vectors.
  NVTE_CHECK(is_aligned_ptr(in, alignof(float)), "Input scaling factor pointer must be aligned to ",
             alignof(float), " bytes");
  NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)),
             "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes");

  // One warp per 128x128 data tile.
  const uint32_t tiles_x = DIVUP(data_cols, 128u);
  const uint32_t tiles_y = DIVUP(data_rows, 128u);
  const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1};
  const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB};

  // Each 128x128 tile in the data corresponds to a 1x1 tile in the input scales
  // and a 128x4 tile in the output scales.
  const uint32_t in_sf_cols = DIVUP(data_cols, 512u) * 4;
  const uint32_t out_sf_cols = tiles_x * 128 * 4;
  swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel<<<grid_dim, block_dim, 0, stream>>>(
      in, out, tiles_x, tiles_y, in_sf_cols * sizeof(float), out_sf_cols * sizeof(uint8_t));
}

}  // namespace swizzle_kernel_2d
// Validate a (1D or 2D) block-scaling tensor / MXFP8 tensor pair and rewrite
// the block-scaling FP32 scale-inverses into swizzled E8M0 MXFP8 scaling
// factors. The FP8 data buffer itself is shared between input and output;
// only the scaling factors are converted.
void swizzle_block_scaling_to_mxfp8_scaling_factors(const Tensor *input, Tensor *output,
                                                    cudaStream_t stream) {
  // Do nothing if tensor is empty
  if (input->data.numel() == 0) {
    return;
  }
  CheckInputTensor(*input, "block_scaling_scaling_factor_input");
  CheckInputTensor(*output, "mxfp8_scaling_factor_output");

  const auto &in_data = input->data;
  const auto &in_sf = input->scale_inv;
  const auto &out_data = output->data;
  const auto &out_sf = output->scale_inv;

  // Scaling-mode compatibility.
  const NVTEScalingMode scaling_mode = input->scaling_mode;
  NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D,
             "Input tensor must be a block scaling tensor");
  NVTE_CHECK(output->scaling_mode == NVTE_MXFP8_1D_SCALING,
             "Output tensor must be an mxfp8 tensor");

  // Dtype compatibility: FP8 data, FP32 scales in, E8M0 scales out.
  NVTE_CHECK(in_data.dtype == DType::kFloat8E4M3 || in_data.dtype == DType::kFloat8E5M2,
             "Input data must have FP8E4M3 or FP8E5M2 dtype to be compatible with MXFP8");
  NVTE_CHECK(out_data.dtype == in_data.dtype, "Output data must have the same dtype as input data");
  NVTE_CHECK(in_sf.dtype == DType::kFloat32, "Input must have FP32 scaling factors");
  NVTE_CHECK(out_sf.dtype == DType::kFloat8E8M0, "Output must have E8M0 scaling factors");

  // Buffer checks: data is aliased between input and output; only the
  // scaling-factor buffers differ.
  NVTE_CHECK(in_data.dptr != nullptr, "Input must have rowwise data");
  NVTE_CHECK(out_data.dptr == in_data.dptr, "Output must share data with input");
  NVTE_CHECK(in_sf.dptr != nullptr, "Input must have rowwise scaling factors");
  NVTE_CHECK(out_sf.dptr != nullptr, "Output must have rowwise scaling factors");

  // Shape checks.
  NVTE_CHECK(in_data.shape.size() == 2, "Input data must be a matrix");
  NVTE_CHECK(out_data.shape == in_data.shape, "Output data must have the same shape as input data");
  NVTE_CHECK(in_sf.shape.size() == 2, "Input scaling factors must be a matrix");
  NVTE_CHECK(out_sf.shape.size() == 2, "Output scaling factors must be a matrix");

  const size_t data_rows = in_data.shape[0];
  const size_t data_cols = in_data.shape[1];

  // Output scaling factors are padded to whole 128x4 blocks per 128x128 data tile.
  const size_t expected_out_rows = DIVUP<size_t>(data_rows, 128) * 128;
  const size_t expected_out_cols = DIVUP<size_t>(data_cols, 128) * 4;
  NVTE_CHECK(out_sf.shape[0] == expected_out_rows,
             "Expected the output scaling factor matrix to have ", expected_out_rows,
             " rows, but it has ", out_sf.shape[0], " rows instead.");
  NVTE_CHECK(out_sf.shape[1] == expected_out_cols,
             "Expected the output scaling factor matrix to have ", expected_out_cols,
             " columns, but it has ", out_sf.shape[1], " columns instead.");

  if (scaling_mode == NVTE_BLOCK_SCALING_1D) {
    // 1D block scales are stored in transposed order: one row per 128-column
    // block of data, columns padded to a multiple of 4.
    const size_t expected_in_rows = DIVUP<size_t>(data_cols, 128);
    const size_t expected_in_cols = DIVUP<size_t>(data_rows, 4) * 4;
    NVTE_CHECK(in_sf.shape[0] == expected_in_rows,
               "Expected the input scaling factor matrix to have ", expected_in_rows,
               " rows, but it has ", in_sf.shape[0], " rows instead.");
    NVTE_CHECK(in_sf.shape[1] == expected_in_cols,
               "Expected the input scaling factor matrix to have ", expected_in_cols,
               " columns, but it has ", in_sf.shape[1], " columns instead.");
    swizzle_kernel_1d::launch_kernel(in_sf.dptr, out_sf.dptr, data_rows, data_cols, stream);
  } else {  // scaling_mode == NVTE_BLOCK_SCALING_2D
    // 2D block scales: one scale per 128x128 data tile, columns padded to a
    // multiple of 4 (i.e. per 512 data columns).
    const size_t expected_in_rows = DIVUP<size_t>(data_rows, 128);
    const size_t expected_in_cols = DIVUP<size_t>(data_cols, 512) * 4;
    NVTE_CHECK(in_sf.shape[0] == expected_in_rows,
               "Expected the input scaling factor matrix to have ", expected_in_rows,
               " rows, but it has ", in_sf.shape[0], " rows instead.");
    NVTE_CHECK(in_sf.shape[1] == expected_in_cols,
               "Expected the input scaling factor matrix to have ", expected_in_cols,
               " columns, but it has ", in_sf.shape[1], " columns instead.");
    swizzle_kernel_2d::launch_kernel(in_sf.dptr, out_sf.dptr, data_rows, data_cols, stream);
  }
}
}
// namespace transformer_engine
// C API entry point: convert block-scaling scale-inverses to swizzled MXFP8
// scaling factors. Thin wrapper that validates the opaque tensor handles and
// forwards to the internal implementation.
void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output,
                                                         cudaStream_t stream) {
  NVTE_API_CALL(nvte_swizzle_block_scaling_to_mxfp8_scaling_factors);
  using namespace transformer_engine;
  auto *const in_tensor = convertNVTETensorCheck(input);
  auto *const out_tensor = convertNVTETensorCheck(output);
  swizzle_block_scaling_to_mxfp8_scaling_factors(in_tensor, out_tensor, stream);
}
transformer_engine/common/transformer_engine.cpp
View file @
063ef88d
...
...
@@ -11,6 +11,7 @@
#include <cstring>
#include <iostream>
#include <mutex>
#include <utility>
#include "common.h"
#include "common/util/cuda_runtime.h"
...
...
@@ -67,8 +68,12 @@ std::string to_string(const NVTEScalingMode &mode) {
return
"NVTE_DELAYED_TENSOR_SCALING"
;
case
NVTE_MXFP8_1D_SCALING
:
return
"NVTE_MXFP8_1D_SCALING"
;
case
NVTE_FWD_NVFP4_BWD_MXFP8_SCALING
:
return
"NVTE_FWD_NVFP4_BWD_MXFP8_SCALING"
;
case
NVTE_BLOCK_SCALING_1D
:
return
"NVTE_BLOCK_SCALING_1D"
;
case
NVTE_BLOCK_SCALING_2D
:
return
"NVTE_BLOCK_SCALING_2D"
;
case
NVTE_NVFP4_1D_SCALING
:
return
"NVTE_NVFP4_1D_SCALING"
;
case
NVTE_INVALID_SCALING
:
return
"NVTE_INVALID_SCALING"
;
}
...
...
@@ -98,12 +103,11 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
t
.
columnwise_scale_inv
.
shape
,
")"
);
}
}
else
{
if
(
t
.
scaling_mode
==
NVTE_MXFP8_1D_SCALING
||
t
.
scaling_mode
==
NVTE_FWD_NVFP4_BWD_MXFP8_SCALING
)
{
if
(
t
.
scaling_mode
==
NVTE_MXFP8_1D_SCALING
)
{
// Need (4, 128) alignment even for e8 scaling factor
auto
block_alignment
=
std
::
vector
<
size_t
>
{
128ul
,
4ul
};
size_t
expected_x
,
expected_y
,
alignment
;
const
size_t
block_size_rowwise
=
(
t
.
scaling_mode
==
NVTE_MXFP8_1D_SCALING
)
?
32
:
16
;
const
size_t
block_size_rowwise
=
32
;
const
size_t
block_size_colwise
=
32
;
if
(
t
.
has_data
())
{
...
...
@@ -114,6 +118,7 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
expected_y
=
DIVUP
(
DIVUP
(
t
.
flat_last_dim
(),
static_cast
<
size_t
>
(
block_size_rowwise
)),
alignment
)
*
alignment
;
const
auto
&
expected
=
std
::
vector
<
size_t
>
{
expected_x
,
expected_y
};
NVTE_CHECK
(
t
.
scale_inv
.
shape
==
expected
,
"Tensor
\"
"
,
name
,
"
\"
has invalid scale_inv shape (expected "
,
expected
,
", got "
,
...
...
@@ -126,11 +131,29 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
alignment
;
alignment
=
block_alignment
[
0
];
expected_y
=
DIVUP
(
DIVUP
(
t
.
flat_last_dim
(),
static_cast
<
size_t
>
(
1
)),
alignment
)
*
alignment
;
const
auto
&
expected
=
std
::
vector
<
size_t
>
{
expected_x
,
expected_y
};
NVTE_CHECK
(
t
.
columnwise_scale_inv
.
shape
==
expected
,
"Tensor
\"
"
,
name
,
"
\"
has invalid columnwise_scale_inv shape (expected "
,
expected
,
", got "
,
t
.
columnwise_scale_inv
.
shape
,
")"
);
}
}
else
if
(
t
.
scaling_mode
==
NVTE_NVFP4_1D_SCALING
)
{
if
(
t
.
has_data
())
{
const
size_t
expected_y
=
DIVUP_TO_MULTIPLE
(
t
.
flat_first_dim
(),
128
);
const
size_t
expected_x
=
DIVUP_TO_MULTIPLE
(
DIVUP
(
t
.
flat_last_dim
(),
16lu
),
4
);
const
auto
&
expected
=
std
::
vector
<
size_t
>
{
expected_y
,
expected_x
};
NVTE_CHECK
(
t
.
scale_inv
.
shape
==
expected
,
"Tensor
\"
"
,
name
,
"
\"
has invalid scale_inv shape (expected "
,
expected
,
", got "
,
t
.
scale_inv
.
shape
,
")"
);
}
if
(
t
.
has_columnwise_data
())
{
const
size_t
expected_y
=
DIVUP_TO_MULTIPLE
(
t
.
flat_last_dim
(),
128
);
const
size_t
expected_x
=
DIVUP_TO_MULTIPLE
(
DIVUP
(
t
.
flat_first_dim
(),
16lu
),
4
);
const
auto
&
expected
=
std
::
vector
<
size_t
>
{
expected_y
,
expected_x
};
NVTE_CHECK
(
t
.
columnwise_scale_inv
.
shape
==
expected
,
"Tensor
\"
"
,
name
,
"
\"
has invalid columnwise_scale_inv shape (expected "
,
expected
,
", got "
,
t
.
columnwise_scale_inv
.
shape
,
")"
);
}
}
}
}
...
...
@@ -158,6 +181,26 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
"(expected Float32 or Byte, got "
,
to_string
(
t
.
columnwise_scale_inv
.
dtype
),
")"
);
}
}
else
if
(
is_fp4_dtype
(
type
))
{
// TODO(ksivaman): Fix this to check for amaxes and other details.
// For now only needed for swizzle.
if
(
t
.
has_data
())
{
NVTE_CHECK
(
t
.
scale_inv
.
dptr
!=
nullptr
,
"FP4 scaling factor input "
,
name
,
"_scale_inverse must be allocated"
);
NVTE_CHECK
(
t
.
scale_inv
.
dtype
==
DType
::
kFloat8E4M3
,
"FP4 scaling factor input "
,
name
,
"_scale_inverse has invalid dtype "
"(expected DType::kFloat8E4M3, got "
,
to_string
(
t
.
scale_inv
.
dtype
),
")"
);
}
if
(
t
.
has_columnwise_data
())
{
NVTE_CHECK
(
t
.
columnwise_scale_inv
.
dptr
!=
nullptr
,
"FP4 scaling factor input "
,
name
,
"_columnwise_scale_inverse must be allocated"
);
NVTE_CHECK
(
t
.
columnwise_scale_inv
.
dtype
==
DType
::
kFloat8E4M3
,
"FP8 scaling factor input "
,
name
,
"_columnwise_scale_inverse has invalid dtype "
"(expected DType::kFloat8E4M3, got "
,
to_string
(
t
.
columnwise_scale_inv
.
dtype
),
")"
);
}
}
else
{
NVTE_CHECK
(
t
.
scale
.
dptr
==
nullptr
,
"Scale is not supported for non-FP8 input "
,
name
);
NVTE_CHECK
(
t
.
amax
.
dptr
==
nullptr
,
"Amax is not supported for non-FP8 input "
,
name
);
...
...
@@ -199,10 +242,29 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
"(expected Float32 or Float8E8M0, got "
,
to_string
(
t
.
columnwise_scale_inv
.
dtype
),
")"
);
}
}
else
if
(
is_fp4_dtype
(
type
))
{
// FP4 output needs to have the scale_inv
if
(
t
.
has_data
())
{
NVTE_CHECK
(
t
.
scale_inv
.
dptr
!=
nullptr
,
"FP4 scaling factor output "
,
name
,
"_scale_inverse must be allocated"
);
NVTE_CHECK
(
t
.
scale_inv
.
dtype
==
DType
::
kFloat8E4M3
,
"FP4 scaling factor output "
,
name
,
"_scale_inverse has invalid dtype "
"(expected Float8E4M3, got "
,
to_string
(
t
.
scale_inv
.
dtype
),
")"
);
}
if
(
t
.
has_columnwise_data
())
{
NVTE_CHECK
(
t
.
columnwise_scale_inv
.
dptr
!=
nullptr
,
"FP4 scaling factor output "
,
name
,
"_columnwise_scale_inverse must be allocated"
);
NVTE_CHECK
(
t
.
columnwise_scale_inv
.
dtype
==
DType
::
kFloat8E4M3
,
"FP4 scaling factor output "
,
name
,
"_columnwise_scale_inverse has invalid dtype "
"(expected Float8E4M3, got "
,
to_string
(
t
.
columnwise_scale_inv
.
dtype
),
")"
);
}
}
else
{
NVTE_CHECK
(
t
.
scale
.
dptr
==
nullptr
,
"Scale is not supported for non-FP8 output "
,
name
);
//
Note: amax is supported for non-FP8 output as it can be fused into the computation
//
and later used for quantization with no need to compute it separately
//
Unfused quant with level 2 nvfp4 scaling will produce high precision tensors with amax.
//
NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name);
NVTE_CHECK
(
t
.
scale_inv
.
dptr
==
nullptr
,
"Scale_inv is not supported for non-FP8 output "
,
name
);
NVTE_CHECK
(
t
.
columnwise_scale_inv
.
dptr
==
nullptr
,
"Scale_inv is not supported for non-FP8 input "
,
name
);
...
...
@@ -507,6 +569,9 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
case
kNVTEColumnwiseScaleInv
:
t
->
columnwise_scale_inv
=
*
param
;
break
;
case
kNVTEColumnwiseAmax
:
t
->
columnwise_amax
=
*
param
;
break
;
default:
NVTE_ERROR
(
"Unknown tensor parameter!"
);
}
...
...
@@ -530,6 +595,8 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p
return
t
.
scale_inv
;
case
kNVTEColumnwiseScaleInv
:
return
t
.
columnwise_scale_inv
;
case
kNVTEColumnwiseAmax
:
return
t
.
columnwise_amax
;
default:
NVTE_ERROR
(
"Unknown tensor parameter!"
);
}
...
...
@@ -645,6 +712,15 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
case
kNVTEQuantizationConfigFloat8BlockScaleTensorFormat
:
std
::
memcpy
(
&
config_
.
float8_block_scale_tensor_format
,
buf
,
attr_size
);
break
;
case
kNVTEQuantizationConfigRNGState
:
std
::
memcpy
(
&
config_
.
rng_state
,
buf
,
attr_size
);
break
;
case
kNVTEQuantizationConfigNVFP42DQuantization
:
std
::
memcpy
(
&
config_
.
nvfp4_2d_quantization
,
buf
,
attr_size
);
break
;
case
kNVTEQuantizationConfigStochasticRounding
:
std
::
memcpy
(
&
config_
.
stochastic_rounding
,
buf
,
attr_size
);
break
;
default:
NVTE_ERROR
(
"Unsupported NVTEQuantizationConfigAttribute (got "
,
static_cast
<
int
>
(
attr
),
")"
);
}
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment