Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
0d874a4e
Commit
0d874a4e
authored
Mar 03, 2026
by
wenjh
Browse files
Merge branch 'nv_main' of v2.12
parents
a68e5f87
dfdd3820
Changes
640
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
626 additions
and
124 deletions
+626
-124
transformer_engine/common/include/transformer_engine/dropout.h
...former_engine/common/include/transformer_engine/dropout.h
+1
-1
transformer_engine/common/include/transformer_engine/fused_attn.h
...mer_engine/common/include/transformer_engine/fused_attn.h
+4
-6
transformer_engine/common/include/transformer_engine/fused_rope.h
...mer_engine/common/include/transformer_engine/fused_rope.h
+1
-1
transformer_engine/common/include/transformer_engine/fused_router.h
...r_engine/common/include/transformer_engine/fused_router.h
+1
-1
transformer_engine/common/include/transformer_engine/gemm.h
transformer_engine/common/include/transformer_engine/gemm.h
+25
-20
transformer_engine/common/include/transformer_engine/hadamard_transform.h
...ne/common/include/transformer_engine/hadamard_transform.h
+64
-1
transformer_engine/common/include/transformer_engine/multi_stream.h
...r_engine/common/include/transformer_engine/multi_stream.h
+1
-1
transformer_engine/common/include/transformer_engine/multi_tensor.h
...r_engine/common/include/transformer_engine/multi_tensor.h
+32
-1
transformer_engine/common/include/transformer_engine/normalization.h
..._engine/common/include/transformer_engine/normalization.h
+12
-4
transformer_engine/common/include/transformer_engine/padding.h
...former_engine/common/include/transformer_engine/padding.h
+1
-1
transformer_engine/common/include/transformer_engine/permutation.h
...er_engine/common/include/transformer_engine/permutation.h
+1
-1
transformer_engine/common/include/transformer_engine/recipe.h
...sformer_engine/common/include/transformer_engine/recipe.h
+186
-3
transformer_engine/common/include/transformer_engine/softmax.h
...former_engine/common/include/transformer_engine/softmax.h
+1
-1
transformer_engine/common/include/transformer_engine/swizzle.h
...former_engine/common/include/transformer_engine/swizzle.h
+5
-5
transformer_engine/common/include/transformer_engine/transformer_engine.h
...ne/common/include/transformer_engine/transformer_engine.h
+233
-67
transformer_engine/common/include/transformer_engine/transpose.h
...rmer_engine/common/include/transformer_engine/transpose.h
+6
-6
transformer_engine/common/multi_tensor/adam.cu
transformer_engine/common/multi_tensor/adam.cu
+1
-1
transformer_engine/common/multi_tensor/compute_scale.cu
transformer_engine/common/multi_tensor/compute_scale.cu
+49
-1
transformer_engine/common/multi_tensor/l2norm.cu
transformer_engine/common/multi_tensor/l2norm.cu
+1
-1
transformer_engine/common/multi_tensor/multi_tensor_apply.cuh
...sformer_engine/common/multi_tensor/multi_tensor_apply.cuh
+1
-1
No files found.
Too many changes to show.
To preserve performance only
640 of 640+
files are displayed.
Plain diff
Email patch
transformer_engine/common/include/transformer_engine/dropout.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/include/transformer_engine/fused_attn.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -131,7 +131,7 @@ enum NVTE_Mask_Type {
* NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
* NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
* NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
* where alpha is a learnable parameter
in
shape [H].
* where alpha is a learnable parameter
of
shape [H].
*/
enum
NVTE_Softmax_Type
{
/*! Vanilla softmax */
...
...
@@ -409,7 +409,6 @@ void nvte_fused_attn_bwd_qkvpacked(
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
...
...
@@ -673,7 +672,7 @@ void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor se
* \param[in] len batch_size x sequence_length.
* \param[in] stream CUDA stream used for this operation.
*/
uint32_t
nvte_get_runtime_num_segments
(
NVTETensor
cu_seqlen
,
NVTETensor
workspace
,
size_t
len
,
uint32_t
nvte_get_runtime_num_segments
(
NVTETensor
cu_seqlen
s
,
NVTETensor
workspace
,
size_t
len
,
cudaStream_t
stream
);
/*! \brief Set the seed and offset for RNG state.
...
...
@@ -830,8 +829,7 @@ void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETens
* \param[in] tensor Input tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[out] new_tensor Output tensor.
* \param[in] b Batch size.
* \param[in] max_seq_len Maximum sequence length.
* \param[in] t Packed sequence length.
* \param[in] stream CUDA stream used for this operation.
*/
void
nvte_convert_bshd_to_thd
(
NVTETensor
tensor
,
NVTETensor
cu_seqlens
,
NVTETensor
new_tensor
,
...
...
transformer_engine/common/include/transformer_engine/fused_rope.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/include/transformer_engine/fused_router.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/include/transformer_engine/gemm.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -57,24 +57,24 @@ NVTEMatmulConfig nvte_create_matmul_config();
/*! \brief Query an option in matrix multiplication configuration.
*
* \param[in] config Matrix multiplication configuration.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value
. Ignored if
* NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
* \param[in]
config
Matrix multiplication configuration.
* \param[in]
attr
Option type.
* \param[out] buf
Memory address to write option value
to.
*
Ignored if
NULL.
* \param[in]
size_in_bytes Size of buf.
* \param[out] size_written
Number of bytes that have been written to
*
buf. If buf is NULL, then the number of
*
bytes that would have been written.
*/
void
nvte_get_matmul_config_attribute
(
NVTEMatmulConfig
config
,
NVTEMatmulConfigAttribute
attr
,
void
*
buf
,
size_t
size_in_bytes
,
size_t
*
size_written
);
/*! \brief Set an option in matrix multiplication configuration.
*
* \param[in] config Matrix multiplication configuration.
* \param[in]
attr
Option type.
* \param[
out] buf
Memory address to read option value.
* \param[in] size_in_bytes Size of buf.
* \param[in
/out
] config
Matrix multiplication configuration.
* \param[in]
attr
Option type.
* \param[
in] buf
Memory address to read option value
from
.
* \param[in]
size_in_bytes Size of buf.
*/
void
nvte_set_matmul_config_attribute
(
NVTEMatmulConfig
config
,
NVTEMatmulConfigAttribute
attr
,
const
void
*
buf
,
size_t
size_in_bytes
);
...
...
@@ -280,9 +280,11 @@ class MatmulConfigWrapper {
MatmulConfigWrapper
(
const
MatmulConfigWrapper
&
)
=
delete
;
MatmulConfigWrapper
&
operator
=
(
const
MatmulConfigWrapper
&
)
=
delete
;
/*! \brief Move constructor. */
MatmulConfigWrapper
(
MatmulConfigWrapper
&&
other
)
:
config_
{
other
.
config_
}
{
other
.
config_
=
nullptr
;
}
/*! \brief Move-assignment operator. */
MatmulConfigWrapper
&
operator
=
(
MatmulConfigWrapper
&&
other
)
{
if
(
config_
!=
nullptr
)
{
nvte_destroy_matmul_config
(
config_
);
...
...
@@ -319,14 +321,15 @@ class MatmulConfigWrapper {
/*! \brief Set whether to compute GELU in GEMM epilogue. */
void
set_with_gelu_epilogue
(
bool
with_gelu_epilogue
)
{
nvte_set_matmul_config_attribute
(
config_
,
kNVTEMatmulConfigWithGELUE
pilogue
,
&
with_gelu_e
pilogue
,
sizeof
(
boo
l
));
const
auto
val
=
static_cast
<
uint8_t
>
(
with_gelu_e
pilogue
);
nvte_set_matmul_config_attribute
(
config_
,
kNVTEMatmulConfigWithGELUE
pilogue
,
&
val
,
sizeof
(
va
l
));
}
/*! \brief Set whether to compute GELU backward in GEMM epilogue. */
void
set_with_dgelu_epilogue
(
bool
with_dgelu_epilogue
)
{
nvte_set_matmul_config_attribute
(
config_
,
kNVTEMatmulConfigWithDGELUEpilogue
,
&
with_dgelu_epilogue
,
sizeof
(
bool
));
const
auto
val
=
static_cast
<
uint8_t
>
(
with_dgelu_epilogue
);
nvte_set_matmul_config_attribute
(
config_
,
kNVTEMatmulConfigWithDGELUEpilogue
,
&
val
,
sizeof
(
val
));
}
/*! \brief Set auxiliary tensor for GEMM epilogue. */
...
...
@@ -337,13 +340,15 @@ class MatmulConfigWrapper {
/*! \brief Set whether to use split accumulator for FP8 GEMM. */
void
set_use_split_accumulator
(
bool
use_split_accumulator
)
{
nvte_set_matmul_config_attribute
(
config_
,
kNVTEMatmulConfigUseSplitAccumulator
,
&
use_split_accumulator
,
sizeof
(
bool
));
const
auto
val
=
static_cast
<
uint8_t
>
(
use_split_accumulator
);
nvte_set_matmul_config_attribute
(
config_
,
kNVTEMatmulConfigUseSplitAccumulator
,
&
val
,
sizeof
(
val
));
}
/*! \brief Set number of streaming multiprocessors to use in GEMM kernel. */
void
set_sm_count
(
int
sm_count
)
{
nvte_set_matmul_config_attribute
(
config_
,
kNVTEMatmulConfigSMCount
,
&
sm_count
,
sizeof
(
int
));
const
auto
val
=
static_cast
<
int32_t
>
(
sm_count
);
nvte_set_matmul_config_attribute
(
config_
,
kNVTEMatmulConfigSMCount
,
&
val
,
sizeof
(
val
));
}
private:
...
...
transformer_engine/common/include/transformer_engine/hadamard_transform.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -61,6 +61,69 @@ void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTE
const
NVTEQuantizationConfig
quant_config
,
cudaStream_t
stream
);
/*! \brief Split a tensor along dimension 0 and compute RHT amaxes for each split.
*
* This function is experimental and the API is not stable.
*
* This is intended for quantizing to NVFP4 with random Hadamard
* transforms (RHT). For each tensor split, compute the maximum
* absolute value (amax) and populate the row-wise amax of the
* corresponding output tensor. Also, compute the amax after a
* transposed RHT and populate the column-wise amax of the
* corresponding output tensor.
*
* \param[in] input Input tensor.
* \param[in,out] outputs Array of NVFP4 output tensors. Only the row-wise and
* column-wise amaxes are updated.
* \param[in] split_sections Size of each tensor split along dimension 0.
* \param[in] num_tensors Number of tensor splits.
* \param[in] random_sign_mask 16-bit sign mask for RHT.
* \param[in] random_sign_mask_t 16-bit sign mask for transposed RHT.
* \param[in] stream CUDA stream used for the operation.
*/
void
nvte_group_hadamard_transform_amax
(
const
NVTETensor
input
,
NVTETensor
*
outputs
,
const
size_t
*
split_sections
,
size_t
num_tensors
,
int
random_sign_mask
,
int
random_sign_mask_t
,
cudaStream_t
stream
);
/*!
* \brief Perform the grouped-tensor columnwise Hadamard transform cast fusion operation.
*
* This function is experimental and the API is not stable. The group_ prefix means that the splits are concatenated contiguously in a single input tensor.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] outputs Array of output tensors.
* \param[in] hadamard_matrix Hadamard matrix to use for transformation.
* \param[in] split_sections Array specifying splits in dimension 0 for each output tensor.
* \param[in] num_tensors Number of output tensors, must be > 0.
* \param[in] quant_config Quantization configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void
nvte_group_hadamard_transform_cast_fusion_columnwise
(
const
NVTETensor
input
,
NVTETensor
*
outputs
,
const
NVTETensor
hadamard_matrix
,
const
size_t
*
split_sections
,
size_t
num_tensors
,
const
NVTEQuantizationConfig
quant_config
,
cudaStream_t
stream
);
/*!
* \brief Perform the grouped-tensor row quantize (without Hadamard) and columnwise Hadamard transform cast fusion operation.
*
* This function is experimental and the API is not stable. The group_ prefix means that the splits are concatenated contiguously in a single input tensor.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] outputs Array of output tensors.
* \param[in] hadamard_matrix Hadamard matrix to use for transformation.
* \param[in] split_sections Array specifying splits in dimension 0 for each output tensor.
* \param[in] num_tensors Number of output tensors, must be > 0.
* \param[in] quant_config Quantization configuration.
* \param[in] quant_workspace Workspace buffer. Must be at least 4 bytes.
* \param[in] stream CUDA stream used for the operation.
*/
void
nvte_group_hadamard_transform_cast_fusion
(
const
NVTETensor
input
,
NVTETensor
*
outputs
,
const
NVTETensor
hadamard_matrix
,
const
size_t
*
split_sections
,
size_t
num_tensors
,
const
NVTEQuantizationConfig
quant_config
,
NVTETensor
quant_workspace
,
cudaStream_t
stream
);
#ifdef __cplusplus
}
// extern "C"
#endif
...
...
transformer_engine/common/include/transformer_engine/multi_stream.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/include/transformer_engine/multi_tensor.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -265,6 +265,37 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETens
float
max_fp8
,
int
force_pow_2_scales
,
float
epsilon
,
cudaStream_t
stream
);
/*! \brief Compute E8M0 scale_inv for a list of tensors.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] stream CUDA stream used for this operation.
*/
void
nvte_multi_tensor_compute_scale_inv_e8m0_cuda
(
int
chunk_size
,
NVTETensor
**
tensor_lists
,
const
size_t
num_tensor_lists
,
const
size_t
num_tensors_per_list
,
cudaStream_t
stream
);
/*! \brief Split a tensor along dimension 0 and compute the amax for each split.
*
* This function is experimental and the API is not stable.
*
* For each tensor split, compute the maximum absolute value (amax)
* and populate the amax of the corresponding output tensor.
*
* \param[in] input Input tensor.
* \param[in,out] outputs Array of output tensors. Only the amax is updated.
* \param[in] split_sections Size of each tensor split along dimension 0.
* \param[in] num_tensors Number of tensor splits.
* \param[in] stream CUDA stream used for the operation.
*/
void
nvte_group_amax
(
const
NVTETensor
input
,
NVTETensor
*
outputs
,
const
size_t
*
split_sections
,
size_t
num_tensors
,
cudaStream_t
stream
);
#ifdef __cplusplus
}
// extern "C"
#endif
...
...
transformer_engine/common/include/transformer_engine/normalization.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -163,11 +163,16 @@ void nvte_rmsnorm_bwd_add(const NVTETensor dz, const NVTETensor x, const NVTETen
NVTETensor
dgamma
,
NVTETensor
workspace
,
const
int
multiprocessorCount
,
const
bool
zero_centered_gamma
,
cudaStream_t
stream
);
/*! \brief
Help
er to enable cuDNN backend for normalization
/*! \brief
Set wheth
er to enable cuDNN backend for normalization
forward.
*
* \param[in]
bool Enable if True
* \param[in]
enable Whether to enable cuDNN backend.
*/
void
nvte_enable_cudnn_norm_fwd
(
bool
enable
);
/*! \brief Set whether to enable cuDNN backend for normalization backward.
*
* \param[in] enable Whether to enable cuDNN backend.
*/
void
nvte_enable_cudnn_norm_bwd
(
bool
enable
);
/*! \brief Control whether norm computes `gamma += 1.0` for zero-centered gamma
...
...
@@ -176,11 +181,14 @@ void nvte_enable_cudnn_norm_bwd(bool enable);
* Currently this only applies to the CuDNN backend. If CuDNN is not used,
* this setting has no effect.
*
* \param[in]
bool Enable if True
* \param[in]
enable Whether to enable zero-centered gamma.
*/
void
nvte_enable_zero_centered_gamma_in_weight_dtype
(
bool
enable
);
#ifdef __cplusplus
/*! \brief Normalization function type */
enum
class
NVTE_Norm_Type
{
LayerNorm
,
RMSNorm
};
#endif
#ifdef __cplusplus
}
// extern "C"
...
...
transformer_engine/common/include/transformer_engine/padding.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/include/transformer_engine/permutation.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/include/transformer_engine/recipe.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -23,7 +23,7 @@ extern "C" {
* the last, the last entry shifts to the second to last) and the
* first entry is set to zero. The scaling factor is estimated so the
* FP8 tensor's maximum absolute value is
* @f$ 2^{-
\text{
margin}
}
\
text{max}_\text
{fp8\_dtype} @f$.
* @f$ 2^{-margin} \
max_
{fp8\_dtype} @f$.
*
* \param[in] amax_history History of maximum absolute values.
* Shape: [history_length, num_scales]
...
...
@@ -54,7 +54,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(
* the last, the last entry shifts to the second to last) and the
* first entry is set to zero. The scaling factor is estimated so the
* FP8 tensor's maximum absolute value is
* @f$ 2^{-
\text{
margin}
}
\
text{max}_\text
{fp8\_dtype} @f$.
* @f$ 2^{-margin} \
max_
{fp8\_dtype} @f$.
*
* \param[in] amax_reduction_buffer The contiguous buffer used for amax reduction.
* Shape: [num_scales * num_tensors]
...
...
@@ -113,17 +113,200 @@ void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output,
void
nvte_compute_scale_from_amax
(
NVTETensor
output
,
const
NVTEQuantizationConfig
config
,
cudaStream_t
stream
);
/*! \brief Compute partial amax for FP8 blockwise scaling.
*
* This function computes the maximum absolute values for each block of the original tensor.
* `inp` contains a continuous segment from the flattened original tensor. For each block,
* if it overlaps with the range [start_offset, start_offset+inp.length), the amax is
* computed from inp; otherwise, the amax is set to 0.
*
* Example: Original tensor (logically 512x512) divided into 16 blocks of size 128x128.
* `inp` contains continuous elements starting from position start_offset
* in the flattened original tensor.
*
* Logical view - Original Tensor (e.g., 512x512) divided into 16 blocks of size 128x128:
* ┌─────────┬─────────┬─────────┬─────────┐
* │ Block0 │ Block1 │ Block2 │ Block3 │ Each block: 128x128
* │ 128x128 │ 128x128 │ 128x128 │ 128x128 │
* ├─────────┼─────────┼─────────┼─────────┤
* │ Block4 │ Block5 │ Block6 │ Block7 │
* ├─────────┼─────────┼─────────┼─────────┤
* │ Block8 │ Block9 │ Block10 │ Block11 │
* ├─────────┼─────────┼─────────┼─────────┤
* │ Block12 │ Block13 │ Block14 │ Block15 │
* └─────────┴─────────┴─────────┴─────────┘
*
* Physical view - Flattened in row-major order:
* ┌────────────────────────────────────────────────────────────────┐
* │[0...128][128...256][256...384][384...512]...[261632...262143] │
* └────────────────────────────────────────────────────────────────┘
* ^ ^
* start_offset start_offset + inp.length
*
* For each 128x128 block, compute amax:
* - If the block overlaps with [start_offset, start_offset+inp.length), compute amax
* - If the block is completely outside this range, set amax = 0
*
* amax output (one value per 128x128 block), block 1 and block 2 are non-zero because they
* overlap with the [start_offset, start_offset+inp.length) range:
* ┌───────┬───────┬───────┬───────┐
* │ 0 │ amax │ amax │ 0 │ Block0-3
* ├───────┼───────┼───────┼───────┤
* │ 0 │ 0 │ 0 │ 0 │ Block4-7
* ├───────┼───────┼───────┼───────┤
* │ 0 │ 0 │ 0 │ 0 │ Block8-11
* ├───────┼───────┼───────┼───────┤
* │ 0 │ 0 │ 0 │ 0 │ Block12-15
* └───────┴───────┴───────┴───────┘
*
* \param[in] inp Input tensor (continuous slice of flattened original tensor).
* \param[in,out] amax Output tensor for maximum absolute values per block.
* \param[in] h Height dimension of the logical tensor.
* \param[in] w Width dimension of the logical tensor.
* \param[in] amax_stride_h Stride in height dimension for amax tensor.
* \param[in] amax_stride_w Stride in width dimension for amax tensor.
* \param[in] start_offset Starting offset in the flattened tensor.
* \param[in] block_len Length of a quantization block to process.
* \param[in] stream CUDA stream used for the operation.
*/
void
nvte_fp8_block_scaling_compute_partial_amax
(
const
NVTETensor
inp
,
NVTETensor
amax
,
size_t
h
,
size_t
w
,
size_t
amax_stride_h
,
size_t
amax_stride_w
,
size_t
start_offset
,
size_t
block_len
,
cudaStream_t
stream
);
/*! \brief Perform partial FP8 casting with blockwise scaling.
*
* This function casts the input tensor to FP8 format using blockwise scaling factors.
* `inp` contains a continuous segment from the flattened original tensor.
*
* \param[in] inp Input tensor.
* \param[out] out Output tensor in FP8 format.
* \param[in] scale Scaling factors per block.
* \param[in] h Height dimension of the tensor.
* \param[in] w Width dimension of the tensor.
* \param[in] scale_stride_h Stride in height dimension for scale tensor.
* \param[in] scale_stride_w Stride in width dimension for scale tensor.
* \param[in] start_offset Starting offset for partial computation.
* \param[in] block_len Length of the block to process.
* \param[in] out_dtype Output FP8 datatype.
* \param[in] stream CUDA stream used for the operation.
*/
void
nvte_fp8_block_scaling_partial_cast
(
const
NVTETensor
inp
,
NVTETensor
out
,
const
NVTETensor
scale
,
size_t
h
,
size_t
w
,
size_t
scale_stride_h
,
size_t
scale_stride_w
,
size_t
start_offset
,
size_t
block_len
,
const
NVTEDType
out_dtype
,
cudaStream_t
stream
);
/*! \brief Compute partial amax for MXFP8 scaling.
*
* This function computes the maximum absolute values along both row and column dimensions.
* input contains a continuous segment from the flattened original tensor. For each row/column
* block, if it overlaps with the range starting from start_offset, the amax is computed from
* `input`; otherwise, the amax is set to 0.
*
* Example: Original tensor (64 rows x 64 cols).
* Rowwise amax granularity: 1x32 (each row divided into 2 blocks)
* Columnwise amax granularity: 32x1 (each column divided into 2 blocks)
* input contains a continuous segment starting from start_offset.
*
* Logical view - Original Tensor (64x64) with 1x32 and 32x1 blocks:
*
* Rowwise blocks (1x32): Each row has 2 blocks
* ┌──────────────┬──────────────┐
* row0 │ Block_r0_0 │ Block_r0_1 │ (cols 0-31, 32-63)
* ├──────────────┼──────────────┤
* row1 │ Block_r1_0 │ Block_r1_1 │
* ├──────────────┼──────────────┤
* ... │ ... │ ... │
* ├──────────────┼──────────────┤
* row63│ Block_r63_0 │ Block_r63_1 │
* └──────────────┴──────────────┘
*
* Columnwise blocks (32x1): Each column has 2 blocks
* ┌───┬───┬─────┬───┬───┐
* │c0 │c1 │ ... │c62│c63│
* ┌────┼───┼───┼─────┼───┼───┤
* │Blk0│ │ │ │ │ │ rows 0-31
* ├────┼───┼───┼─────┼───┼───┤
* │Blk1│ │ │ │ │ │ rows 32-63
* └────┴───┴───┴─────┴───┴───┘
*
* Physical view - Flattened in row-major order:
* Total elements: 64*64 = 4096
* ┌──────────────────────────────────────────────────────┐
* │[0...63][64...127][128...191]...[4032...4095] │
* └──────────────────────────────────────────────────────┘
* ^ ^
* start_offset=60 start_offset + input.length=130
*
* Row-wise amax output (one value per 1x32 block):
* ┌────────┬────────┐
* │ amax │ amax │ row0 (block0 and block1 partially covered)
* ├────────┼────────┤
* │ 0 │ 0 │ row1 (not covered)
* ├────────┼────────┤
* │ ... │ ... │
* ├────────┼────────┤
* │ 0 │ 0 │ row63 (not covered)
* └────────┴────────┘
*
* Column-wise amax output (one value per 32x1 block):
* ┌────────┬────────┬────────┬────────┬────────┬────────┬────────┐
* │ amax │ amax │ amax │ amax │ amax │ amax │ amax │ ... row 0-31
* ├────────┼────────┼────────┼────────┼────────┼────────┼────────┤
* │ amax=0 │ amax=0 │ amax=0 │ amax=0 │ amax=0 │ amax=0 │ amax=0 │ ... row 32-63
* └────────┴────────┴────────┴────────┴────────┴────────┴────────┘
* col0 col1 col2 col3 col4 col5 col6
*
* For each 1x32 or 32x1 block, if it overlaps with [start_offset, start_offset+input.length),
* compute amax; otherwise set to 0.
*
* \param[in] input Input tensor (continuous segment of flattened original tensor).
* \param[in,out] amax_rowwise Output tensor for row-wise maximum absolute values.
* \param[in,out] amax_colwise Output tensor for column-wise maximum absolute values.
* \param[in] rows Number of rows in the logical tensor.
* \param[in] cols Number of columns in the logical tensor.
* \param[in] start_offset Starting offset in the flattened tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void
nvte_mxfp8_scaling_compute_partial_amax
(
const
NVTETensor
input
,
NVTETensor
amax_rowwise
,
NVTETensor
amax_colwise
,
int
rows
,
int
cols
,
size_t
start_offset
,
cudaStream_t
stream
);
/*! \brief Perform partial MXFP8 casting.
*
* This function casts the input tensor to MXFP8 format, producing both row-wise and
* column-wise scaled outputs. input contains a continuous segment from the flattened
* original tensor.
*
* \param[in] input Input (continuous segment of flattened original tensor).
* \param[out] output_rowwise Output tensor with row-wise scaling (MXFP8 format).
* \param[out] output_colwise Output tensor with column-wise scaling (MXFP8 format).
* \param[in] scale_inv_rowwise Inverse scaling factors for row-wise scaling.
* \param[in] scale_inv_colwise Inverse scaling factors for column-wise scaling.
* \param[in] rows Number of rows in the logical tensor.
* \param[in] cols Number of columns in the logical tensor.
* \param[in] start_offset Starting offset in the flattened tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void
nvte_mxfp8_scaling_partial_cast
(
const
NVTETensor
input
,
NVTETensor
output_rowwise
,
NVTETensor
output_colwise
,
const
NVTETensor
scale_inv_rowwise
,
const
NVTETensor
scale_inv_colwise
,
int
rows
,
int
cols
,
size_t
start_offset
,
cudaStream_t
stream
);
/*! \brief Compute per-tensor scaling factor for NVFP4 format.
*
* This function computes the scaling factor (alpha) for NVFP4 quantization based
* on the input tensors A and B, with options for using row-wise amax values.
*
* \param[in] inpA Input tensor A.
* \param[in] use_rowwise_amax_A Whether to use row-wise amax for tensor A.
* \param[in] inpB Input tensor B.
* \param[in] use_rowwise_amax_B Whether to use row-wise amax for tensor B.
* \param[in] alpha_in Input scaling factor.
* \param[out] alpha_out Output scaling factor.
* \param[in] stream CUDA stream used for the operation.
*/
void
nvte_nvfp4_compute_per_tensor_scale
(
const
NVTETensor
inpA
,
const
bool
use_rowwise_amax_A
,
const
NVTETensor
inpB
,
const
bool
use_rowwise_amax_B
,
float
alpha_in
,
NVTETensor
alpha_out
,
cudaStream_t
stream
);
...
...
transformer_engine/common/include/transformer_engine/softmax.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/include/transformer_engine/swizzle.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file
cast
.h
* \brief Functions to c
ast to/from FP8
.
/*! \file
swizzle
.h
* \brief Functions to c
onvert scaling factors into format expected by GEMM
.
*/
#ifndef TRANSFORMER_ENGINE_SWIZZLE_H_
...
...
@@ -34,6 +34,7 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud
*
* \param[in] inputs Input tensors with non-swizzled scale_inv.
* \param[in,out] outputs Output tensors which hosts swizzled scale_inv.
* \param[in] num_tensors Number of input and output tensors.
* \param[in] stream CUDA stream used for the operation.
*
* Requirements:
...
...
@@ -46,7 +47,7 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen
/*! \brief Swizzle the scaling factors of an FP8 block-scaled tensor into the MXFP8 interleaved layout for GEMM
*
* \param[in] input Input FP8 block
scal
ing
tensor
with GEMM_READY scale_inv
.
* \param[in] input Input FP8 block
-
scal
ed
tensor.
* \param[in,out] output Output mxfp8 tensor which hosts swizzled scale_inv.
* \param[in] stream CUDA stream used for the operation.
*
...
...
@@ -56,7 +57,6 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen
* Requirements:
* - input is an FP8 block scaling tensor
* - input has rowwise usage
* - input.scale_inv is in GEMM_READY format
* - output is an MXFP8 tensor
* - output has rowwise usage
* - output.scale_inv has appropriate shape
...
...
transformer_engine/common/include/transformer_engine/transformer_engine.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -13,6 +13,7 @@
#include <cuda_runtime_api.h>
#include <stddef.h>
#include <stdint.h>
#ifdef __cplusplus
extern
"C"
{
...
...
@@ -52,8 +53,11 @@ struct NVTEShape {
* It does not own the memory it points to.
*/
struct
NVTEBasicTensor
{
/*! Pointer to data buffer. The struct is a non-owning view of this memory. */
void
*
data_ptr
;
/*! Data type associated with the buffer. */
NVTEDType
dtype
;
/*! Tensor shape associated with the buffer. */
NVTEShape
shape
;
};
...
...
@@ -61,13 +65,14 @@ struct NVTEBasicTensor {
* \brief Indicates the kind of the tensor parameter to set/get.
*/
enum
NVTETensorParam
{
kNVTERowwiseData
=
0
,
/*!< Data usable in rowwise manner */
kNVTEColumnwiseData
=
1
,
/*!< Data usable in columnwise manner */
kNVTEScale
=
2
,
/*!< Scale tensor */
kNVTEAmax
=
3
,
/*!< Amax tensor */
kNVTERowwiseScaleInv
=
4
,
/*!< Scale inverse tensor for decoding Rowwise Data */
kNVTEColumnwiseScaleInv
=
5
,
/*!< Scale inverse tensor for decoding Columnwise Data */
kNVTEColumnwiseAmax
=
6
,
/*!< Columnwise Amax tensor */
kNVTERowwiseData
=
0
,
/*!< Data usable in rowwise manner */
kNVTEColumnwiseData
=
1
,
/*!< Data usable in columnwise manner */
kNVTEScale
=
2
,
/*!< Scale tensor */
kNVTEAmax
=
3
,
/*!< Amax tensor */
kNVTERowwiseScaleInv
=
4
,
/*!< Scale inverse tensor for decoding Rowwise Data */
kNVTEColumnwiseScaleInv
=
5
,
/*!< Scale inverse tensor for decoding Columnwise Data */
kNVTEColumnwiseAmax
=
6
,
/*!< Columnwise Amax tensor */
kNVTEWithGEMMSwizzledScales
=
7
,
/*!< Whether scaling factors are in format expected by GEMM */
kNVTENumTensorParams
};
...
...
@@ -143,8 +148,9 @@ void *nvte_tensor_columnwise_data(const NVTETensor tensor);
/*! \brief Construct a shape from an array of dimension sizes.
*
* \param[data] Pointer to start of shape array.
* \param[data] Number of dimensions (must be <= 14)
* \param[in] data Pointer to start of shape array. If NULL, the shape
* will be filled with zeros.
* \param[in] ndim Number of dimensions (must be <= 14)
*
* \return A shape. The shape will own its own copy of the data.
*/
...
...
@@ -177,7 +183,7 @@ size_t nvte_tensor_ndims(const NVTETensor tensor);
/*! \brief Get the size of a specific tensor dimension.
*
* \param[in] tensor Tensor.
* \param[in]
size_t
Dimension index.
* \param[in]
dim
Dimension index.
*
* \return Size of the tensor at the specified dimension.
*/
...
...
@@ -258,12 +264,13 @@ NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor);
/*! \brief Reset tensor value to zero.
*
* \param[in] tensor Tensor.
*
* \return A scale_inv shape of the input tensor.
* \param[in] stream CUDA stream to use for the operation.
*/
void
nvte_zero_tensor
(
const
NVTETensor
tensor
,
cudaStream_t
stream
);
/*! \brief Set a parameter of the tensor.
*
* \warning Deprecated in favor of nvte_set_tensor_param_v2.
*
* \param[in,out] tensor Tensor.
* \param[in] param_name The parameter to be set.
...
...
@@ -273,12 +280,38 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
const
NVTEBasicTensor
*
param
);
/*! \brief Get a value of the parameter of the tensor.
*
* \warning Deprecated in favor of nvte_get_tensor_param_v2.
*
* \param[in] tensor Tensor.
* \param[in] param_name The parameter to be queried.
*/
NVTEBasicTensor
nvte_get_tensor_param
(
const
NVTETensor
tensor
,
NVTETensorParam
param_name
);
/*! \brief Set a tensor parameter.
*
* \param[in,out] tensor Tensor.
* \param[in] param Tensor parameter type.
* \param[in] buf Memory address to read parameter value.
* \param[in] size_in_bytes Size of buf.
*/
void
nvte_set_tensor_param_v2
(
NVTETensor
tensor
,
NVTETensorParam
param
,
const
void
*
buf
,
size_t
size_in_bytes
);
/*! \brief Query a tensor parameter.
*
* \param[in] tensor Tensor.
* \param[in] param Tensor parameter type.
* \param[out] buf Memory address to write parameter value.
* Ignored if NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
*/
void
nvte_get_tensor_param_v2
(
const
NVTETensor
tensor
,
NVTETensorParam
param
,
void
*
buf
,
size_t
size_in_bytes
,
size_t
*
size_written
);
/*! \brief Get the granularity of scaling of this tensor.
*
* \param[in] tensor Tensor.
...
...
@@ -324,12 +357,7 @@ enum NVTEQuantizationConfigAttribute {
conditional early even when captured in a static CUDA graph.
*/
kNVTEQuantizationConfigNoopTensor
=
2
,
/*! Data format for an FP8 block-scaled tensor
*
* This is not the right design since the tensor format is a
* property of the tensor, not the quantization. This enum will
* likely be refactored away in the future.
*/
/*! \warning Deprecated */
kNVTEQuantizationConfigFloat8BlockScaleTensorFormat
=
3
,
/*! RNG state (NVTETensor with 2 elements - seed and offset) */
kNVTEQuantizationConfigRNGState
=
4
,
...
...
@@ -337,6 +365,12 @@ enum NVTEQuantizationConfigAttribute {
kNVTEQuantizationConfigNVFP42DQuantization
=
5
,
/*! Whether to enable stochastic rounding */
kNVTEQuantizationConfigStochasticRounding
=
6
,
/*! Whether to enable fast math operations with reduced accuracy.
*
* Optimizations are kernel-specific and they may be applied
* inconsistently between kernels.
*/
kNVTEQuantizationConfigUseFastMath
=
7
,
kNVTEQuantizationConfigNumAttributes
};
...
...
@@ -347,14 +381,14 @@ NVTEQuantizationConfig nvte_create_quantization_config();
/*! \brief Query an option in quantization config.
*
* \param[in] config Quantization config.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value.
Ignored if
* NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
* \param[in]
config
Quantization config.
* \param[in]
attr
Option type.
* \param[out] buf
Memory address to write option value.
*
Ignored if
NULL.
* \param[in]
size_in_bytes Size of buf.
* \param[out] size_written
Number of bytes that have been written to
*
buf. If buf is NULL, then the number of
*
bytes that would have been written.
*/
void
nvte_get_quantization_config_attribute
(
NVTEQuantizationConfig
config
,
NVTEQuantizationConfigAttribute
attr
,
void
*
buf
,
...
...
@@ -362,10 +396,10 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
/*! \brief Set an option in quantization config.
*
* \param[in] config Quantization config.
* \param[in]
attr
Option type.
* \param[
out] buf
Memory address to read option value.
* \param[in] size_in_bytes Size of buf.
* \param[in
/out
] config
Quantization config.
* \param[in]
attr
Option type.
* \param[
in] buf
Memory address to read option value.
* \param[in]
size_in_bytes Size of buf.
*/
void
nvte_set_quantization_config_attribute
(
NVTEQuantizationConfig
config
,
NVTEQuantizationConfigAttribute
attr
,
const
void
*
buf
,
...
...
@@ -394,6 +428,114 @@ int nvte_is_non_tn_fp8_gemm_supported();
*/
void
nvte_memset
(
void
*
ptr
,
int
value
,
size_t
size_in_bytes
,
cudaStream_t
stream
);
/*! \brief TE Grouped Tensor type
*
* NVTEGroupedTensor is a collection of tensors with potentially different shapes
* but the same dtype and scaling mode. It does not own the memory it points to.
*/
typedef
void
*
NVTEGroupedTensor
;
/*! \enum NVTEGroupedTensorParam
* \brief Indicates the kind of the grouped tensor parameter to set/get.
*/
enum NVTEGroupedTensorParam {
  /*! Data usable in rowwise manner */
  kNVTEGroupedRowwiseData = 0,
  /*! Data usable in columnwise manner */
  kNVTEGroupedColumnwiseData = 1,
  /*! Scale tensor */
  kNVTEGroupedScale = 2,
  /*! Amax tensor */
  kNVTEGroupedAmax = 3,
  /*! Scale inverse tensor for decoding Rowwise Data */
  kNVTEGroupedRowwiseScaleInv = 4,
  /*! Scale inverse tensor for decoding Columnwise Data */
  kNVTEGroupedColumnwiseScaleInv = 5,
  /*! Columnwise Amax tensor */
  kNVTEGroupedColumnwiseAmax = 6,
  /*! First dimension sizes (device pointer to int64_t array) */
  kNVTEGroupedFirstDims = 7,
  /*! Last dimension sizes (device pointer to int64_t array) */
  kNVTEGroupedLastDims = 8,
  /*! Tensor offsets for contiguous layout (device pointer to int64_t array) */
  kNVTEGroupedTensorOffsets = 9,
  /*! Sentinel: number of parameter kinds listed above. Must stay last. */
  kNVTENumGroupedTensorParams
};
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Create a new TE grouped tensor.
*
* Create a new TE grouped tensor. Before use its parameters need to be set.
* TE grouped tensors are just wrappers on top of raw data and do not
* own memory.
*
* \param[in] scaling_mode Scaling mode of the grouped tensor.
* \param[in] num_tensors Number of tensors in the group (must be > 0).
* \param[in] logical_shape Logical 2D shape of the grouped data.
*
* \return A new TE grouped tensor.
*/
NVTEGroupedTensor
nvte_create_grouped_tensor
(
NVTEScalingMode
scaling_mode
,
size_t
num_tensors
,
NVTEShape
logical_shape
);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Destroy a TE grouped tensor.
*
* Since the TE grouped tensor does not own memory, the underlying
* data is not freed during this operation.
*
* \param[in] tensor Grouped tensor to be destroyed.
*/
void
nvte_destroy_grouped_tensor
(
NVTEGroupedTensor
tensor
);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Set a parameter of the grouped tensor.
*
* \param[in,out] tensor Grouped tensor.
* \param[in] param_name The parameter to be set.
* \param[in] param The value to be set (NVTEBasicTensor).
*/
void
nvte_set_grouped_tensor_param
(
NVTEGroupedTensor
*
tensor
,
NVTEGroupedTensorParam
param_name
,
const
NVTEBasicTensor
*
param
);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get a value of the parameter of the grouped tensor.
*
* \param[in] tensor Grouped tensor.
* \param[in] param_name The parameter to be queried.
*
* \return NVTEBasicTensor containing the parameter data.
*/
NVTEBasicTensor
nvte_get_grouped_tensor_param
(
const
NVTEGroupedTensor
tensor
,
NVTEGroupedTensorParam
param_name
);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get the number of tensors in a grouped tensor.
*
* \param[in] tensor Grouped tensor.
*
* \return Number of tensors in the group.
*/
size_t
nvte_grouped_tensor_num_tensors
(
const
NVTEGroupedTensor
tensor
);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get a grouped tensor's data type.
*
* \param[in] tensor Grouped tensor.
*
* \return A data type of the grouped tensor.
*/
NVTEDType
nvte_grouped_tensor_type
(
const
NVTEGroupedTensor
tensor
);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get a scaling mode of the grouped tensor.
*
* \param[in] tensor Grouped tensor.
*
* \return Scaling mode of the grouped tensor.
*/
NVTEScalingMode
nvte_grouped_tensor_scaling_mode
(
const
NVTEGroupedTensor
tensor
);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get the logical shape of a grouped tensor.
*
* \param[in] tensor Grouped tensor.
*
* \return Logical 2D shape.
*/
NVTEShape
nvte_get_grouped_tensor_logical_shape
(
const
NVTEGroupedTensor
tensor
);
#ifdef __cplusplus
}
// extern "C"
...
...
@@ -435,7 +577,7 @@ inline bool is_int8_dtype(const DType t) {
/*! \brief Check if TE datatype is FP8
*
* Return true if TE datatype is FP8
* \param[in]
DType
TE Datatype of interest
* \param[in]
t
TE Datatype of interest
*/
inline
bool
is_fp8_dtype
(
const
DType
t
)
{
return
t
==
DType
::
kFloat8E4M3
||
t
==
DType
::
kFloat8E5M2
;
...
...
@@ -444,7 +586,7 @@ inline bool is_fp8_dtype(const DType t) {
/*! \brief Check if TE datatype is FP4
*
* Return true if TE datatype is FP4
* \param[in]
DType
TE Datatype of interest
* \param[in]
t
TE Datatype of interest
*/
inline
bool
is_fp4_dtype
(
const
DType
t
)
{
return
t
==
DType
::
kFloat4E2M1
;
...
...
@@ -453,7 +595,7 @@ inline bool is_fp4_dtype(const DType t) {
/*! \brief Check if TE datatype is high precision (FP32, FP16, BF16)
*
* Return true if TE datatype is high precision
* \param[in]
DType
TE Datatype of interest
* \param[in]
t
TE Datatype of interest
*/
inline
bool
is_high_precision_dtype
(
const
DType
t
)
{
return
t
==
DType
::
kFloat32
||
t
==
DType
::
kBFloat16
||
t
==
DType
::
kFloat16
;
...
...
@@ -477,20 +619,28 @@ class TensorWrapper {
* \param[in] scale_dptr Pointer to the scale value.
* \param[in] scale_inv_shape Shape of scale_inv
* \param[in] scale_inv_dptr Pointer to the inverse of scale value.
* \param[in] scaling_mode Tensor data format.
*/
TensorWrapper
(
void
*
dptr
,
const
NVTEShape
&
shape
,
const
DType
dtype
,
float
*
amax_dptr
=
nullptr
,
float
*
scale_dptr
=
nullptr
,
float
*
scale_inv_dptr
=
nullptr
,
const
NVTEShape
scale_inv_shape
=
defaultShape
,
NVTEShape
scale_inv_shape
=
defaultShape
,
const
NVTEScalingMode
scaling_mode
=
NVTE_DELAYED_TENSOR_SCALING
)
{
tensor_
=
nvte_create_tensor
(
scaling_mode
);
NVTEBasicTensor
data
=
{
dptr
,
static_cast
<
NVTEDType
>
(
dtype
),
shape
};
nvte_set_tensor_param
(
&
tensor_
,
kNVTERowwiseData
,
&
data
);
NVTEBasicTensor
amax
=
{
amax_dptr
,
kNVTEFloat32
,
defaultShape
};
nvte_set_tensor_param
(
&
tensor_
,
kNVTEAmax
,
&
amax
);
NVTEBasicTensor
scale
=
{
scale_dptr
,
kNVTEFloat32
,
defaultShape
};
nvte_set_tensor_param
(
&
tensor_
,
kNVTEScale
,
&
scale
);
nvte_set_tensor_param_v2
(
tensor_
,
kNVTERowwiseData
,
&
data
,
sizeof
(
data
));
NVTEBasicTensor
amax
=
{
amax_dptr
,
kNVTEFloat32
,
amax_dptr
!=
nullptr
?
defaultShape
:
emptyShape
};
nvte_set_tensor_param_v2
(
tensor_
,
kNVTEAmax
,
&
amax
,
sizeof
(
amax
));
NVTEBasicTensor
scale
=
{
scale_dptr
,
kNVTEFloat32
,
scale_dptr
!=
nullptr
?
defaultShape
:
emptyShape
};
nvte_set_tensor_param_v2
(
tensor_
,
kNVTEScale
,
&
scale
,
sizeof
(
scale
));
if
(
scale_inv_dptr
==
nullptr
&&
scale_inv_shape
.
ndim
==
defaultShape
.
ndim
&&
scale_inv_shape
.
ndim
==
1
&&
scale_inv_shape
.
data
[
0
]
==
defaultShape
.
data
[
0
])
{
// Scale-inv pointer has not been provided and shape matches default
scale_inv_shape
=
emptyShape
;
}
NVTEBasicTensor
scale_inv
=
{
scale_inv_dptr
,
kNVTEFloat32
,
scale_inv_shape
};
nvte_set_tensor_param
(
&
tensor_
,
kNVTERowwiseScaleInv
,
&
scale_inv
);
nvte_set_tensor_param
_v2
(
tensor_
,
kNVTERowwiseScaleInv
,
&
scale_inv
,
sizeof
(
scale_inv
)
);
}
/*! \brief Constructs new TensorWrapper.
...
...
@@ -506,6 +656,7 @@ class TensorWrapper {
* \param[in] scale_dptr Pointer to the scale value.
* \param[in] scale_inv_shape Shape of scale_inv
* \param[in] scale_inv_dptr Pointer to the inverse of scale value.
* \param[in] scaling_mode Tensor data format.
*/
TensorWrapper
(
void
*
dptr
,
const
std
::
vector
<
size_t
>
&
shape
,
const
DType
dtype
,
float
*
amax_dptr
=
nullptr
,
float
*
scale_dptr
=
nullptr
,
...
...
@@ -560,7 +711,7 @@ class TensorWrapper {
const
ShapeType
&
shape
)
noexcept
{
NVTEShape
nvte_shape
=
this
->
convertShape
(
shape
);
NVTEBasicTensor
data
=
{
dptr
,
static_cast
<
NVTEDType
>
(
type
),
nvte_shape
};
nvte_set_tensor_param
(
&
tensor_
,
param
,
&
data
);
nvte_set_tensor_param
_v2
(
tensor_
,
param
,
&
data
,
sizeof
(
data
)
);
return
*
this
;
}
...
...
@@ -599,10 +750,17 @@ class TensorWrapper {
return
set_parameter
(
kNVTEColumnwiseAmax
,
dptr
,
type
,
shape
);
}
/*! \brief Record whether the scaling factors are in the format expected by GEMM.
 *
 *  The flag is handed to nvte_set_tensor_param_v2 as a single byte (0 or 1).
 *
 *  \param[in] with_gemm_swizzled_scales  Value of the flag to store on the tensor.
 */
void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales) {
  const uint8_t flag = with_gemm_swizzled_scales ? 1 : 0;
  nvte_set_tensor_param_v2(tensor_, kNVTEWithGEMMSwizzledScales, &flag, sizeof(flag));
}
// Parameter getters
NVTEBasicTensor
get_parameter
(
const
NVTETensorParam
param
)
const
noexcept
{
return
nvte_get_tensor_param
(
tensor_
,
param
);
NVTEBasicTensor
ret
;
nvte_get_tensor_param_v2
(
tensor_
,
param
,
&
ret
,
sizeof
(
ret
),
nullptr
);
return
ret
;
}
/*! \brief Fetch the rowwise data parameter of the wrapped tensor.
 *
 *  \return Result of get_parameter for kNVTERowwiseData.
 */
NVTEBasicTensor get_rowwise_data() const noexcept {
  constexpr NVTETensorParam which = kNVTERowwiseData;
  return get_parameter(which);
}
...
...
@@ -627,6 +785,12 @@ class TensorWrapper {
return
get_parameter
(
kNVTEColumnwiseAmax
);
}
/*! \brief Query whether the scaling factors are in the format expected by GEMM.
 *
 *  The flag is read from the C API as a single byte. The local byte starts
 *  at zero, so the result is false unless the query writes a non-zero value.
 *
 *  \return true if the stored flag byte is non-zero.
 */
bool get_with_gemm_swizzled_scales() const {
  uint8_t flag = 0;
  nvte_get_tensor_param_v2(tensor_, kNVTEWithGEMMSwizzledScales, &flag, sizeof(flag), nullptr);
  return flag != 0;
}
/*! \brief Get an underlying NVTETensor.
*
* \return NVTETensor held by this TensorWrapper.
...
...
@@ -639,7 +803,7 @@ class TensorWrapper {
*/
const
NVTEShape
shape
()
const
noexcept
{
if
(
tensor_
==
nullptr
)
{
return
nvte_make_shape
(
nullptr
,
0
)
;
return
emptyShape
;
}
return
nvte_tensor_shape
(
tensor_
);
}
...
...
@@ -650,14 +814,14 @@ class TensorWrapper {
*/
const
NVTEShape
columnwise_shape
()
const
noexcept
{
if
(
tensor_
==
nullptr
)
{
return
nvte_make_shape
(
nullptr
,
0
)
;
return
emptyShape
;
}
return
nvte_tensor_columnwise_shape
(
tensor_
);
}
/*! \brief Get the size of this TensorWrapper in the given dimension.
*
* \param[in]
size_t
Dimension index.
* \param[in]
dim
Dimension index.
*
* \return Size of this TensorWrapper in given dimension.
*/
...
...
@@ -774,7 +938,7 @@ class TensorWrapper {
*/
const
NVTEShape
scale_inv_shape
()
const
noexcept
{
if
(
tensor_
==
nullptr
)
{
return
nvte_make_shape
(
nullptr
,
0
)
;
return
emptyShape
;
}
return
nvte_tensor_scale_inv_shape
(
tensor_
);
}
...
...
@@ -793,6 +957,7 @@ class TensorWrapper {
/*! Extent used for the first (and only) dimension of defaultShape. */
static
constexpr
size_t
defaultData
=
1
;
/*! Shape whose first extent is defaultData (1) and whose trailing scalar is 1.
 *  NOTE(review): reads as a 1-D shape {1}, assuming NVTEShape lists its extents
 *  array before ndim - confirm against the NVTEShape definition.
 */
static
constexpr
NVTEShape
defaultShape
=
{
{
defaultData
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
},
1
};
/*! Shape whose extents are all zero and whose trailing scalar is 1.
 *  NOTE(review): reads as a 1-D shape with a zero-length dimension, under the
 *  same assumption about NVTEShape member order - confirm.
 */
static
constexpr
NVTEShape
emptyShape
=
{{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
},
1
};
private:
/*! Identity conversion: the argument is already an NVTEShape, so a copy of it
 *  is returned unchanged.
 */
NVTEShape
convertShape
(
const
NVTEShape
&
s
)
{
return
s
;
}
...
...
@@ -805,15 +970,8 @@ class TensorWrapper {
NVTETensor
tensor_
=
nullptr
;
};
/*! \enum Float8BlockScaleTensorFormat
* \brief Data format for an FP8 block-scaled tensor
*/
enum
class
Float8BlockScaleTensorFormat
{
/*! FP8 data is transposed if needed and scales are swizzled */
GEMM_READY
=
0
,
/*! FP8 data is untransposed and scales are not swizzled or padded */
COMPACT
=
1
};
/*! \warning Deprecated */
enum
class
Float8BlockScaleTensorFormat
{
GEMM_READY
=
0
,
COMPACT
=
1
,
INVALID
};
/*! \struct QuantizationConfigWrapper
* \brief C++ wrapper for NVTEQuantizationConfig.
...
...
@@ -825,9 +983,11 @@ class QuantizationConfigWrapper {
QuantizationConfigWrapper
(
const
QuantizationConfigWrapper
&
)
=
delete
;
QuantizationConfigWrapper
&
operator
=
(
const
QuantizationConfigWrapper
&
)
=
delete
;
/*! \brief Move constructor.
 *
 *  Takes over the config handle held by \p other and resets other's handle
 *  to nullptr, so that only this wrapper refers to the config afterwards.
 *  The operation only copies and clears a pointer, hence it is noexcept;
 *  this lets standard containers move wrappers instead of rejecting them.
 *
 *  \param[in,out] other  Wrapper to move from; holds nullptr on return.
 */
QuantizationConfigWrapper(QuantizationConfigWrapper &&other) noexcept : config_{other.config_} {
  other.config_ = nullptr;
}
/*! \brief Move-assignment operator. */
QuantizationConfigWrapper
&
operator
=
(
QuantizationConfigWrapper
&&
other
)
{
if
(
config_
!=
nullptr
)
{
nvte_destroy_quantization_config
(
config_
);
...
...
@@ -852,8 +1012,9 @@ class QuantizationConfigWrapper {
/*! \brief Set whether to force power of 2 scales */
void
set_force_pow_2_scales
(
bool
force_pow_2_scales
)
{
nvte_set_quantization_config_attribute
(
config_
,
kNVTEQuantizationConfigForcePow2Scales
,
&
force_pow_2_scales
,
sizeof
(
bool
));
const
auto
val
=
static_cast
<
uint8_t
>
(
force_pow_2_scales
);
nvte_set_quantization_config_attribute
(
config_
,
kNVTEQuantizationConfigForcePow2Scales
,
&
val
,
sizeof
(
val
));
}
/*! \brief Set small value to add to amax */
...
...
@@ -868,12 +1029,8 @@ class QuantizationConfigWrapper {
sizeof
(
NVTETensor
));
}
/*! \brief Set FP8 block-scaled tensor format */
void
set_float8_block_scale_tensor_format
(
Float8BlockScaleTensorFormat
format
)
{
nvte_set_quantization_config_attribute
(
config_
,
kNVTEQuantizationConfigFloat8BlockScaleTensorFormat
,
&
format
,
sizeof
(
Float8BlockScaleTensorFormat
));
}
/*! \warning Deprecated */
void
set_float8_block_scale_tensor_format
(
Float8BlockScaleTensorFormat
format
)
{}
/*! \brief Set stochastic rounding state */
void
set_rng_state
(
NVTETensor
rng_state
)
{
...
...
@@ -883,14 +1040,23 @@ class QuantizationConfigWrapper {
/*! \brief Set whether to use 2D block scaling for NVFP4 */
void
set_nvfp4_2d_quantization
(
bool
nvfp4_2d_quantization
)
{
const
auto
val
=
static_cast
<
uint8_t
>
(
nvfp4_2d_quantization
);
nvte_set_quantization_config_attribute
(
config_
,
kNVTEQuantizationConfigNVFP42DQuantization
,
&
nvfp4_2d_quantization
,
sizeof
(
boo
l
));
&
val
,
sizeof
(
va
l
));
}
/*! \brief Set whether to use stochastic rounding */
void
set_stochastic_rounding
(
bool
stochastic_rounding
)
{
nvte_set_quantization_config_attribute
(
config_
,
kNVTEQuantizationConfigStochasticRounding
,
&
stochastic_rounding
,
sizeof
(
bool
));
const
auto
val
=
static_cast
<
uint8_t
>
(
stochastic_rounding
);
nvte_set_quantization_config_attribute
(
config_
,
kNVTEQuantizationConfigStochasticRounding
,
&
val
,
sizeof
(
val
));
}
/*! \brief Enable or disable fast math operations.
 *
 *  The flag is handed to nvte_set_quantization_config_attribute as a single
 *  byte (0 or 1) under kNVTEQuantizationConfigUseFastMath.
 *
 *  \param[in] use_fast_math  Value of the flag to store in the config.
 */
void set_use_fast_math(bool use_fast_math) {
  const uint8_t flag = use_fast_math ? 1 : 0;
  nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigUseFastMath, &flag,
                                         sizeof(flag));
}
private:
...
...
transformer_engine/common/include/transformer_engine/transpose.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -231,7 +231,7 @@ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor a
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in]
gated_
act_input Tensor used as input to the forward of
* \param[in] act_input
Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
...
...
@@ -250,7 +250,7 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_inp
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in]
gated_
act_input Tensor used as input to the forward of
* \param[in] act_input
Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
...
...
@@ -269,7 +269,7 @@ void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor act_in
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in]
gated_
act_input Tensor used as input to the forward of
* \param[in] act_input
Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
...
...
@@ -288,7 +288,7 @@ void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor act_inp
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in]
gated_
act_input Tensor used as input to the forward of
* \param[in] act_input
Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
...
...
@@ -307,7 +307,7 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_in
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in]
gated_
act_input Tensor used as input to the forward of
* \param[in] act_input
Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
...
...
transformer_engine/common/multi_tensor/adam.cu
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/multi_tensor/compute_scale.cu
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -17,6 +17,7 @@
#include <sstream>
#include "../recipe/recipe_common.cuh"
#include "../util/ptx.cuh"
#include "../utils.cuh"
#include "multi_tensor_apply.cuh"
...
...
@@ -58,6 +59,28 @@ struct ComputeScaleAndScaleInvFunctor {
}
};
// Device functor that turns amax values into e8m0 inverse scales.
//
// tl.addresses[0][t] is read as a bf16 amax buffer and tl.addresses[1][t] is
// written as an e8m0_t scale_inv buffer. The tensor and chunk handled by the
// calling thread block are looked up through blockIdx.x, and the threads of
// the block step through the chunk in increments of blockDim.x.
struct ComputeScaleInvE8M0Functor {
  __device__ __forceinline__ void operator()(int chunk_size, volatile int *unused,
                                             TensorListMetadata<2> &tl) {
    const int tensor_idx = tl.block_to_tensor[blockIdx.x];
    const int chunk = tl.block_to_chunk[blockIdx.x];
    const int chunk_begin = chunk * chunk_size;

    // Number of elements of this tensor located at or after the chunk start.
    const int tensor_size = tl.sizes[tensor_idx];
    const int remaining = tensor_size - chunk_begin;
    // Process at most one chunk worth of elements, never past the tensor end.
    const int count = remaining < chunk_size ? remaining : chunk_size;

    bf16 *amax_chunk = reinterpret_cast<bf16 *>(tl.addresses[0][tensor_idx]) + chunk_begin;
    e8m0_t *scale_inv_chunk =
        reinterpret_cast<e8m0_t *>(tl.addresses[1][tensor_idx]) + chunk_begin;

    for (int idx = threadIdx.x; idx < count; idx += blockDim.x) {
      // scale_inv = e8m0(amax * reciprocal of the largest normal fp8e4m3 value)
      const float amax_val = static_cast<float>(amax_chunk[idx]);
      scale_inv_chunk[idx] =
          ptx::float_to_e8m0(amax_val * Quantized_Limits<fp8e4m3>::max_norm_rcp);
    }
  }
};
void
multi_tensor_compute_scale_and_scale_inv_cuda
(
int
chunk_size
,
Tensor
noop_flag
,
std
::
vector
<
std
::
vector
<
Tensor
*>>
tensor_lists
,
float
max_fp8
,
bool
force_pow_2_scales
,
...
...
@@ -68,6 +91,19 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_f
NVTE_CHECK_CUDA
(
cudaGetLastError
());
}
// Host launcher for ComputeScaleInvE8M0Functor.
//
// tensor_lists must hold exactly two non-empty lists: tensor_lists[0] with the
// bf16 amax tensors and tensor_lists[1] with the e8m0/uint8 scale_inv tensors.
// Every tensor is validated, because the device functor reinterprets each
// buffer with those element types without further checks.
//
// chunk_size is forwarded to multi_tensor_apply; the work is enqueued on stream.
void multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size,
                                              std::vector<std::vector<Tensor *>> tensor_lists,
                                              cudaStream_t stream) {
  // Guard the element accesses below: indexing a missing or empty list is
  // undefined behavior.
  NVTE_CHECK(tensor_lists.size() == 2, "expected two tensor lists (amax and scale_inv)");
  NVTE_CHECK(!tensor_lists[0].empty(), "amax tensor list should not be empty");
  NVTE_CHECK(!tensor_lists[1].empty(), "scale_inv tensor list should not be empty");

  for (const Tensor *amax : tensor_lists[0]) {
    NVTE_CHECK(amax->data.dtype == DType::kBFloat16, "amax should be bf16");
  }
  for (const Tensor *scale_inv : tensor_lists[1]) {
    const auto scale_inv_dtype = scale_inv->data.dtype;
    NVTE_CHECK(scale_inv_dtype == DType::kByte || scale_inv_dtype == DType::kFloat8E8M0,
               "scale_inv should be e8m0/uint8");
  }

  // Default-constructed placeholder for the Tensor argument of
  // multi_tensor_apply; the functor names the matching parameter "unused".
  Tensor dummy;
  multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, dummy, tensor_lists,
                        ComputeScaleInvE8M0Functor(), stream);
  // Surface launch-configuration errors from the kernel launch.
  NVTE_CHECK_CUDA(cudaGetLastError());
}
}
// namespace multi_tensor_compute_scale
}
// namespace transformer_engine
...
...
@@ -85,3 +121,15 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETens
convert_tensor_array
(
tensor_lists
,
num_tensor_lists
,
num_tensors_per_list
),
max_fp8
,
force_pow_2_scales
,
epsilon
,
stream
);
}
// C API entry point: converts the raw NVTETensor list array with
// convert_tensor_array and forwards it, together with chunk_size and stream,
// to multi_tensor_compute_scale::multi_tensor_compute_scale_inv_e8m0_cuda.
void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor **tensor_lists,
                                                   const size_t num_tensor_lists,
                                                   const size_t num_tensors_per_list,
                                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_tensor_compute_scale_inv_e8m0_cuda);
  using namespace transformer_engine;
  namespace impl = transformer_engine::multi_tensor_compute_scale;

  impl::multi_tensor_compute_scale_inv_e8m0_cuda(
      chunk_size, convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
      stream);
}
transformer_engine/common/multi_tensor/l2norm.cu
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/multi_tensor/multi_tensor_apply.cuh
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
Prev
1
…
14
15
16
17
18
19
20
21
22
…
32
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment