OpenDAS / TransformerEngine / Commits

Commit 063ef88d, authored Dec 03, 2025 by wenjh

    Merge nv main up to v2.10.0.dev0

    Signed-off-by: wenjh <wenjh@sugon.com>

Parents: 91670b05, 5624dbb4
Changes: 298 files changed in total. Showing the first 20 changed files, with 3534 additions and 139 deletions (+3534, -139).
transformer_engine/common/transpose/cast_transpose.h (+9, -0)
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu (+6, -0)
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu (+6, -0)
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu (+842, -0)
transformer_engine/common/util/cast_gated_kernels.cuh (+94, -55)
transformer_engine/common/util/cast_kernels.cuh (+761, -46)
transformer_engine/common/util/dequantize_kernels.cuh (+96, -10)
transformer_engine/common/util/logging.h (+10, -0)
transformer_engine/common/util/math.h (+35, -4)
transformer_engine/common/util/nvfp4_transpose.cuh (+1516, -0)
transformer_engine/common/util/ptx.cuh (+79, -3)
transformer_engine/common/util/pybind_helper.h (+5, -0)
transformer_engine/common/util/vectorized_pointwise.h (+20, -6)
transformer_engine/common/utils.cuh (+20, -0)
transformer_engine/debug/features/fake_quant.py (+1, -1)
transformer_engine/debug/features/log_fp8_tensor_stats.py (+9, -3)
transformer_engine/debug/features/log_tensor_stats.py (+12, -6)
transformer_engine/debug/features/utils/stats_buffer.py (+9, -1)
transformer_engine/debug/pytorch/debug_quantization.py (+2, -2)
transformer_engine/jax/__init__.py (+2, -2)
transformer_engine/common/transpose/cast_transpose.h
@@ -8,6 +8,7 @@
 #define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_

 #include "../common.h"
+#include "transformer_engine/transformer_engine.h"

 namespace transformer_engine::detail {

@@ -62,6 +63,14 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor
                                          const bool pow_2_scale, const SimpleTensor &noop_tensor,
                                          cudaStream_t stream);

+void quantize_transpose_vector_blockwise_fp4(
+    const SimpleTensor &input, const SimpleTensor &global_amax, SimpleTensor &scale_inv,
+    SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon,
+    const bool return_identity, const bool return_transpose, const bool pow2_scale,
+    const bool swizzled_scale, const bool use_stochastic_rounding,
+    const NVTETensor rng_state_tensor, const bool use_2d_quantization,
+    const SimpleTensor &noop_tensor, cudaStream_t stream);
+
 }  // namespace transformer_engine::detail

 #endif  // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
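[Editor's note] For orientation, here is a minimal sketch of how the new FP4 entry point declared above might be driven. It assumes the SimpleTensor arguments have already been bound to device buffers elsewhere; the wrapper function itself is ours, not part of this commit.

```cpp
// A minimal sketch, assuming pre-populated SimpleTensor objects; not part of the commit.
#include "common/transpose/cast_transpose.h"

void quantize_fp4_example(const transformer_engine::SimpleTensor &input,
                          const transformer_engine::SimpleTensor &global_amax,
                          transformer_engine::SimpleTensor &scale_inv,
                          transformer_engine::SimpleTensor &scale_inv_t,
                          transformer_engine::SimpleTensor &output,
                          transformer_engine::SimpleTensor &output_t,
                          const transformer_engine::SimpleTensor &noop,
                          cudaStream_t stream) {
  // Quantize both the identity and the transpose, with swizzled scales and
  // round-to-nearest (no stochastic rounding, so no RNG state is needed).
  transformer_engine::detail::quantize_transpose_vector_blockwise_fp4(
      input, global_amax, scale_inv, scale_inv_t, output, output_t,
      /*epsilon=*/0.0f, /*return_identity=*/true, /*return_transpose=*/true,
      /*pow2_scale=*/false, /*swizzled_scale=*/true,
      /*use_stochastic_rounding=*/false, /*rng_state_tensor=*/nullptr,
      /*use_2d_quantization=*/false, noop, stream);
}
```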
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu
@@ -18,6 +18,7 @@
 #include "common/common.h"
 #include "common/recipe/recipe_common.cuh"
 #include "common/util/cuda_runtime.h"
+#include "common/util/ptx.cuh"
 #include "common/utils.cuh"

@@ -901,6 +902,11 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
   NVTE_API_CALL(quantize_transpose_square_blockwise);
   checkCuDriverContext(stream);
+  if (transformer_engine::cuda::sm_arch() >= 100) {
+    NVTE_CHECK(pow_2_scale,
+               "On Blackwell and newer, the FP8 block scaling recipe is emulated ",
+               "with MXFP8, which requires using power of two scaling factors.");
+  }
   NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
   const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
   size_t num_rows = 1;
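[Editor's note] One way to read the power-of-two requirement added above: an MXFP8-style scale stores only a biased exponent (E8M0), so it can represent exactly the powers of two. Below is a rough host-side sketch of deriving such a scale from an amax; the rounding convention (round the exponent up so the quantized range covers amax) is our assumption, and the library's actual device-side conversion is ptx::float_to_e8m0 elsewhere in this commit.

```cpp
// Illustrative only; the real conversion happens on-device via ptx::float_to_e8m0.
#include <cmath>
#include <cstdint>

uint8_t pow2_scale_from_amax(float amax, float otype_max) {
  // Round the ideal scale amax / otype_max up to the next power of two and
  // store only its exponent with the standard bias of 127.
  const float ideal = amax / otype_max;
  const int exponent = static_cast<int>(std::ceil(std::log2(ideal)));
  return static_cast<uint8_t>(exponent + 127);
}
```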
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu
@@ -24,6 +24,7 @@
 #include "common/common.h"
 #include "common/recipe/recipe_common.cuh"
 #include "common/transpose/cast_transpose.h"
+#include "common/util/cuda_runtime.h"
 #include "common/utils.cuh"

 namespace transformer_engine {

@@ -1480,6 +1481,11 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
                                          cudaStream_t stream) {
   NVTE_API_CALL(quantize_transpose_vector_blockwise);
+  if (transformer_engine::cuda::sm_arch() >= 100) {
+    NVTE_CHECK(pow2_scale,
+               "On Blackwell and newer, the FP8 block scaling recipe is emulated ",
+               "with MXFP8, which requires using power of two scaling factors.");
+  }
   const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
   size_t num_elements = row_length;
   size_t num_rows = 1;
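[Editor's note] Both blockwise entry points flatten an N-D tensor into a 2-D (num_rows, row_length) view before launching. A short standalone sketch of that flattening, mirroring the loop that appears later in this commit's FP4 host function:

```cpp
// Flattening an N-D shape into (num_rows, row_length), as the host functions do.
#include <cstddef>
#include <vector>

void flatten_shape(const std::vector<size_t> &shape, size_t *num_rows, size_t *row_length) {
  *row_length = shape.size() > 0 ? shape.at(shape.size() - 1) : 1u;
  *num_rows = 1;
  for (size_t i = 0; shape.size() > 0 && i < shape.size() - 1; ++i) {
    *num_rows *= shape.at(i);
  }
  // e.g. shape {4, 8, 128} flattens to num_rows = 32, row_length = 128.
}
```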
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu (new file, 0 → 100644)
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <cfloat>
#include <cuda/barrier>
#include <utility>

#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "curanddx.hpp"

namespace transformer_engine {

#if CUDA_VERSION >= 12080

namespace quantize_transpose_nvfp4 {

namespace {

using std::int32_t;
using std::uint32_t;
using std::uint8_t;
using transformer_engine::detail::TypeExtrema;

// Define a cuRANDDx descriptor.
// Note: curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not
// specified, it defaults to 10.
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for
// some internal checks, e.g. whether shared memory, if needed, is enough for the described
// problem; usually not applicable.
// cuRANDDx doc: https://docs.nvidia.com/cuda/curanddx/index.html
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
                     curanddx::SM<800>() + curanddx::Thread());
// clang-format off
/*
Step 1: Load input to shared memory
* shared memory: 128x128 elements with type=InputType (the graph below doesn't consider padding)
* 8 warps
* Loop 8 times
* What each thread does in each loop:
* 8 elements are read from the input at a time
* 2 elements are written to the shared memory at a time, for a total of 4 times
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 1 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 7 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 8 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 2: Cast and store to output_c
* shared memory: 128x128 elements with type=InputType (the graph below doesn't consider padding)
* 8 warps
* Loop 4 times
* What each thread does in each loop:
* 2 elements are read from the shared memory at a time, for a total of 8 times
* Every 8 consecutive threads do reduction and calculate the amax of each row
* 16 elements are quantized and write to output_c at a time
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 |
| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 |
| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 1 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 7 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 4 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 3: Transpose, cast and store to output_t
* shared memory: 128x128 elements with type=InputType (the graph below doesn't consider padding)
* 8 warps
* Loop 2 times
* What each thread does in each loop:
* 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times
* Every 8 consecutive threads do reduction and calculate the amax of each column
* 16 elements are quantized and write to output_c at a time, for a total of 2 times
+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+
| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | |
| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | |
| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | |
| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 |
| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | |
| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | |
| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | |
| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | |
+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+
*/
// clang-format on
constexpr int kThreadsPerWarp = 32;
// For FP4, we use uint8_t to store 2 FP4 numbers.
constexpr int kNFP4PerContainer = 2;

// Hyperparameters for performance tuning
constexpr int kTileDim = 128;
// constexpr int kScaleDim = 32;
constexpr int kNVecIn = 8;             // The number of elements each LDG touches
constexpr int kNVecOut = 16;           // The number of elements each STG touches
constexpr int kNVecSMem = 2;           // The number of elements each LDS/STS touches
constexpr int kThreadsPerBlock = 256;  // Thread block size, 8 warps in total

// Auto-calculated constants, do not modify directly
static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem");
static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem");
constexpr int kSMemRow = kTileDim;
constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1;
constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem;
constexpr int kNumThreadsLoad = kTileDim / kNVecIn;    // 16
constexpr int kNumThreadsStore = kTileDim / kNVecOut;  // 8
// constexpr int kNumThreadsReduce = kScaleDim / kNVecOut;
static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");

// For 2D block scaling, we need to reduce amax within a warp.
static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = {
    0x01010101, 0x02020202, 0x04040404, 0x08080808,
    0x10101010, 0x20202020, 0x40404040, 0x80808080};
// max for every group_size elements in warp
template <int group_size, int shfl_down_stride>
__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) {
  for (int offset = group_size / 2; offset > 0; offset /= 2) {
    val = max(val, __shfl_down_sync(groupMask, val, offset * shfl_down_stride));
  }
  return val;
}
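// [Editor's aside, not part of this commit] Worked example of groupMax: with
// group_size = 4 and shfl_down_stride = 8, lanes {0, 8, 16, 24} form one group
// (mask 0x01010101 from WARP_REDUCE_AMAX_GROUP_MASKS). The loop shuffles with
// offsets 2*8 = 16 and then 1*8 = 8, so lane 0 folds in lane 16, then lane 8
// (which already folded in lane 24), leaving the group's maximum in lane 0.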
template <typename ScaleType>
__device__ __forceinline__ ScaleType ComputeDecodeScaleFP4(const float amax,
                                                           const float global_encode_scale) {
  float decode_scale = amax / TypeExtrema<fp4e2m1>::max;
  decode_scale = decode_scale * global_encode_scale;
  decode_scale = fminf(decode_scale, TypeExtrema<float>::max);
  return static_cast<ScaleType>(decode_scale);
}

template <typename ScaleType>
__device__ __forceinline__ float ComputeEncodeScaleFP4(ScaleType decode_scale,
                                                       const float global_decode_scale) {
  return fminf(1.0f / (static_cast<float>(decode_scale) * global_decode_scale),
               TypeExtrema<float>::max);
}

template <typename IType, typename ScaleType>
__device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scale) {
  return static_cast<float>(input) * encode_scale;
}

__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) {
  constexpr float fp8_max = TypeExtrema<fp8e4m3>::max;
  constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;
  float global_encode_scale = fp8_max * fp4_max / global_amax;
  // If scale is infinity, return max value of float32
  global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
  // If global amax is 0 or infinity, return 1
  if (global_amax == 0.f || global_encode_scale == 0.f) {
    return 1.f;
  }
  return global_encode_scale;
}
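// [Editor's aside, not part of this commit] Worked example of the two-level
// NVFP4 scaling above, with fp8_max = 448 (E4M3) and fp4_max = 6 (E2M1): for
// global_amax = 448 the global encode scale is 448 * 6 / 448 = 6. A 16-element
// block with local amax = 3 then gets decode_scale = (3 / 6) * 6 = 3 (stored
// as FP8-E4M3), and elements are encoded with encode_scale = 1 / (3 * (1/6))
// = 2, mapping the block's amax onto the FP4 maximum of 6.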
__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) {
  if (rnd_idx == 4) {
    rnd_idx = 0;
    curanddx::uniform_bits dist;
    random_uint4 = dist.generate4(rng);
  }
  // Treat uint4 as an array of 4x uint32_t elements for indexing
  const uint32_t *const rbits_arr = reinterpret_cast<uint32_t *>(&random_uint4);
  const uint32_t rbits = rbits_arr[rnd_idx++];
  return rbits;
}
template <class ScaleType>
__device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, size_t col_idx,
                                                               uint32_t col_length) {
  // This function takes in indices from the scale factor matrix and returns an offset in the
  // swizzled format. row_idx, col_idx are original indices from the scale factor matrix
  // (unswizzled index). col_length is the column length of the scale factor matrix.
  // tile_scales_inv is the pointer to the scale factor matrix.
  // https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts
  // Any scale factor matrix is built from 512B base blocks. Each base block consists of 128 rows
  // and 4 columns, and is divided into 4 column blocks of 32 rows and 4 columns each.
  // NOTE: There are not a lot of good illustrations of the swizzled scale factor matrix.
  // At a high level, the swizzled scale factor matrix could be composed as:
  //   unswizzled_scale_factor_matrix = torch.empty((M, N // 16), dtype=torch.uint8)
  //   cbg_cnt = N // 16 // 4  # Assuming N is divisible by 64
  //   rb_cnt = M // 128       # Assuming M is divisible by 128
  //   tmp = unswizzled_scale_factor_matrix.reshape(rb_cnt, 4, 32, cbg_cnt, 4)
  //   tmp = torch.permute(tmp, (0, 3, 2, 1, 4))
  //   swizzled_scale_factor_matrix = tmp.reshape((-1, 128, 4))
  constexpr uint32_t kTotalRowsPerBaseBlock = 128;
  constexpr uint32_t kRowsPerBaseBlockCol = 32;
  constexpr uint32_t kColsPerBaseBlockCol = 4;
  const size_t rb = row_idx / kTotalRowsPerBaseBlock;
  const size_t rem = row_idx % kTotalRowsPerBaseBlock;
  const size_t d4 = rem / kRowsPerBaseBlockCol;
  const size_t d3 = rem % kRowsPerBaseBlockCol;
  const size_t cbg = col_idx / kColsPerBaseBlockCol;
  const size_t d5 = col_idx % kColsPerBaseBlockCol;
  const size_t cbg_cnt = DIVUP(col_length, kColsPerBaseBlockCol);
  // Row-major offset in the logical shape (rb_cnt, cbg_cnt, 32, 4, 4).
  // The magic number 16 below comes from the fact that kColsPerBaseBlockCol = 4 and
  // d4 ([0-128] / 32 = [0-4]) spans 4 values.
  return ((rb * cbg_cnt + cbg) * kRowsPerBaseBlockCol + d3) * 16 + d4 * kColsPerBaseBlockCol + d5;
}
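// [Editor's aside, not part of this commit] Worked example of the swizzle:
// for row_idx = 130, col_idx = 5, col_length = 8: rb = 1, rem = 2, d4 = 0,
// d3 = 2, cbg = 1, d5 = 1, cbg_cnt = DIVUP(8, 4) = 2, so the offset is
// ((1 * 2 + 1) * 32 + 2) * 16 + 0 * 4 + 1 = 98 * 16 + 1 = 1569.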
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding(
    const float2 in01, const float2 in23, const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  uint16_t out_4x;
  asm volatile(
      "{\n"
      "cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5;\n\t"
      "}"
      : "=h"(out_4x)
      : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits));
  return *reinterpret_cast<__nv_fp4x4_e2m1 *>(&out_4x);
#else
  NVTE_DEVICE_ERROR(
      "FP4 cvt PTX instructions are architecture-specific. "
      "Try recompiling with sm_XXXa instead of sm_XXX.");
  uint16_t dummy = 0;
  return *reinterpret_cast<__nv_fp4x4_e2m1 *>(&dummy);
#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
                                                                      const float2 in23,
                                                                      const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  // NOTE: rbits unused for rn.
  uint32_t out_4x;  // Only 16 bits are needed; a 32-bit container is used for packing.
  asm volatile(
      "{\n"
      ".reg.b8 f0;\n\t"
      ".reg.b8 f1;\n\t"
      "cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t"
      "cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t"
      "mov.b32 %0, {f0, f1, f0, f1};\n\t"
      "}"
      : "=r"(out_4x)
      : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x));
  return reinterpret_cast<__nv_fp4x4_e2m1 *>(&out_4x)[0];
#else
  NVTE_DEVICE_ERROR(
      "FP4 cvt PTX instructions are architecture-specific. "
      "Try recompiling with sm_XXXa instead of sm_XXX.");
  uint16_t dummy = 0;
  return *reinterpret_cast<__nv_fp4x4_e2m1 *>(&dummy);
#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
template <bool kApplyStochasticRounding>
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01,
                                                              const float2 in23,
                                                              const uint32_t rbits) {
  if constexpr (kApplyStochasticRounding) {
    return cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, rbits);
  } else {
    return cvt_fp32_to_fp4_4x_with_rn(in01, in23, rbits);
  }
}
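// [Editor's aside, not part of this commit] cvt.rs ("rs" = stochastic
// rounding) consumes the 32 random bits in `rbits` to round four FP32 values
// to FP4 at once, so one uint4 from the Philox generator (4 x 32 bits) covers
// four such conversions before get_rbits() must refill it.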
template <bool kReturnIdentity, bool kReturnTranspose, bool kIsE8Scaling, bool kAligned,
          typename CType, typename IType, typename OType, typename ScaleType, bool kSwizzledScale,
          bool kApplyStochasticRounding, bool kIs2DBlockScaling>
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
    const IType *const input, const float *global_amax, OType *const output_c,
    OType *const output_t, ScaleType *const tile_scales_inv_c, ScaleType *const tile_scales_inv_t,
    const size_t row_length, const size_t num_rows, const size_t scale_stride_x,
    const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y,
    const size_t kScaleBlockDim, const float epsilon, const size_t *rng_state,
    const float *noop_ptr) {
  constexpr int kNVecContainer = kNVecOut / kNFP4PerContainer;
  using SMemVec = Vec<IType, kNVecSMem>;
  using OVec = Vec<OType, kNVecContainer>;
  union IVec {
    Vec<IType, kNVecIn> input_type;
    Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
  };

  if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
    return;
  }

  const size_t block_idx_x = blockIdx.x;
  const size_t block_idx_y = blockIdx.y;

  const size_t rng_sequence =
      threadIdx.x + block_idx_x * kThreadsPerBlock + block_idx_y * gridDim.x * kThreadsPerBlock;
  const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
  const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
  RNG rng(rng_seed, rng_sequence, rng_offset);
  curanddx::uniform_bits dist;
  uint4 random_uint4 = kApplyStochasticRounding ? dist.generate4(rng) : uint4{0, 0, 0, 0};
  // Index of the random number. It increments each time it is used and resets to 0 when it
  // reaches 4.
  int rnd_idx = 0;

  extern __shared__ char smem_base[];
  SMemVec *smem = reinterpret_cast<SMemVec *>(&smem_base[0]);

  // 2D block scaling is not supported for E8 scaling (MXFP4) or for colwise-only mode.
  // Instead of static_assert, return early if these invalid modes are detected.
  if constexpr (kIs2DBlockScaling && kIsE8Scaling) {
    return;
  }
  if constexpr (kIs2DBlockScaling && !kReturnIdentity) {
    return;
  }

  // For a 128x128 block, 2D block scaling means there will be 8x8 amax values for NVFP4
  // (4x4 for 2D MXFP4). Use constexpr to define the size; when not using 2D, use the
  // minimal size 1x1.
  constexpr int kFP4BlockScalingSize = 16;
  constexpr int k2DBlockAmaxDim = kIs2DBlockScaling ? (kTileDim / kFP4BlockScalingSize) : 1;
  constexpr int kNumRowsPerWarp = kThreadsPerWarp / kNumThreadsStore;  // 4
  constexpr int k2DBlockAmaxReduceDim =
      kIs2DBlockScaling ? (kFP4BlockScalingSize / kNumRowsPerWarp) : 1;
  __shared__ CType amax_smem_red[k2DBlockAmaxDim][k2DBlockAmaxDim][k2DBlockAmaxReduceDim];
  __shared__ CType amax_smem[k2DBlockAmaxDim][k2DBlockAmaxDim];
  // Step 1: Load input to shared memory
  {
    constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad;  // stride in rows of shared memory
    constexpr int num_iterations = kTileDim / r_stride;
    const int c_s = (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem);  // Column in shared memory
    int r_s = threadIdx.x / kNumThreadsLoad;                                  // Row in shared memory
    const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem;  // Column in global memory
    size_t r_g = block_idx_y * kTileDim + r_s;                    // Row in global memory
    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;  // Stride in global memory
    const size_t num_ele =
        (c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g) : 0);  // For not aligned case
    const IType *input_g = &input[r_g * row_length + c_g];  // Input address in global memory

#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      IVec input_vec;
      // Step 1.1: Load from global memory (input) to registers
      if constexpr (kAligned) {
        input_vec.input_type.load_from(input_g);
      } else {
        if (r_g < num_rows) {
          input_vec.input_type.load_from_elts(input_g, 0, num_ele);
        } else {
          input_vec.input_type.clear();
        }
      }
      // Step 1.2: Write to shared memory
#pragma unroll
      for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
        int c = c_s + i;
        int r = r_s;
        smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
      }
      // Step 1.3: Update input address, row index of shared memory (and row index of global
      // memory for not aligned case)
      input_g += stride_g;
      r_s += r_stride;
      if constexpr (!kAligned) {
        r_g += r_stride;
      }
    }
  }
  __syncthreads();
  const int kNumThreadsReduce = kScaleBlockDim / kNVecOut;
  const float global_encode_scale =
      kIsE8Scaling ? 1.0f : ComputeGlobalEncodeScaleFP4(global_amax[0]);
  const float global_decode_scale = 1.0 / global_encode_scale;
  // Step 2: Cast and store to output_c
  if constexpr (kReturnIdentity) {
    constexpr int r_stride = kThreadsPerBlock / kNumThreadsStore;  // stride in rows of shared memory
    constexpr int num_iterations = kTileDim / r_stride;
    const int c_s = (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem);  // Column in shared memory
    int r_s = threadIdx.x / kNumThreadsStore;                                   // Row in shared memory
    const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem;  // Column in global memory
    size_t r_g = block_idx_y * kTileDim + r_s;                    // Row in global memory
    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;  // Stride in global memory
    const size_t num_ele =
        (c_g < row_length ? min(static_cast<size_t>(kNVecOut / kNFP4PerContainer),
                                (row_length - c_g) / kNFP4PerContainer)
                          : 0);  // For not aligned case
    OType *output_g = &output_c[(r_g * row_length + c_g) / kNFP4PerContainer];  // Output address in global memory
    // Each group of kNumThreadsStore threads in a warp processes one row; we need to find the
    // lane id of the first thread to do the reduction.
    const unsigned src_lane =
        (threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce;
    // This mask represents which threads should do the reduction together.
    const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane;
    const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0;

#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      SMemVec smem_vec[kNVecOut / kNVecSMem];
      // Step 2.1: Load from shared memory to registers
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
        int c = c_s + i;
        int r = r_s;
        smem_vec[i] = smem[r * kSMemCol + c];
      }
      // Step 2.2: Compute local amax
      CType amax = 0;
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
        for (int j = 0; j < kNVecSMem; ++j) {
          __builtin_assume(amax >= 0);
          amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
        }
      }
      // Step 2.3: Reduce amax
      if constexpr (kIsE8Scaling) {
#pragma unroll
        for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) {
          const float other_amax = __shfl_down_sync(mask, amax, delta);
          __builtin_assume(amax >= 0);
          __builtin_assume(other_amax >= 0);
          amax = fmaxf(amax, other_amax);
        }
        amax = __shfl_sync(mask, amax, src_lane);
      }
      // Shuffle-sync for 2D block scaling (not applicable for E8 scaling)
      if constexpr (kIs2DBlockScaling) {
        // First amax shuffle-sync in the warp, then reduce in shared memory.
        // T0, T8, T16, T24 should do the amax reduction together.
        constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore;  // 32
        int warp_idx = threadIdx.x / kThreadsPerWarp;                         // 0 ~ 7
        int tid_in_warp_x = threadIdx.x % kNumThreadsStore;
        int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp;
        CType amax_warp_reduced = groupMax<kNumRowsPerWarp, kNumThreadsStore>(
            amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]);
        // Now T0 ~ T8 in each warp have the reduced amax values.
        int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y;
        if (tid_in_warp_y == 0) {
          amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x]
                       [warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced;
        }
        __syncthreads();
        if (data_row_idx % kFP4BlockScalingSize == 0) {
          CType amax_2d = 0.0;
          for (int i = 0; i < k2DBlockAmaxReduceDim; i++) {
            amax_2d = fmaxf(amax_2d,
                            amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]);
          }
          amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d;
        }
        __syncthreads();
        // Every thread now knows the 2D amax.
        amax = amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x];
      }
      // Step 2.4: Compute scale
      ScaleType scale_inv = ComputeDecodeScaleFP4<ScaleType>(amax, global_encode_scale);
      float encode_scale = ComputeEncodeScaleFP4<ScaleType>(scale_inv, global_decode_scale);
      // Step 2.5: Write scale_inv
      bool write_scale_inv = is_src_lane;
      if constexpr (!kAligned) {
        write_scale_inv &= (r_g < num_rows);
        write_scale_inv &= (c_g < row_length);
      }
      if (write_scale_inv) {
        size_t row_idx = block_idx_y * kTileDim + r_s;
        size_t col_idx = block_idx_x * (kNumThreadsStore / kNumThreadsReduce) +
                         (threadIdx.x % kNumThreadsStore) / kNumThreadsReduce;
        if constexpr (kSwizzledScale) {
          size_t offset = scale_factor_swizzled_offset<ScaleType>(
              row_idx, col_idx, DIVUP(row_length, kScaleBlockDim));
          tile_scales_inv_c[offset] = scale_inv;
        } else {
          tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
        }
      }
      // Step 2.6: Quantize
      OVec output_vec;
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; i += 2) {
        // Pack two elements into each float2
        float2 f2_a;
        float2 f2_b;
        f2_a.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[i].data.elt[0], encode_scale);
        f2_a.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[i].data.elt[1], encode_scale);
        f2_b.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[i + 1].data.elt[0], encode_scale);
        f2_b.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[i + 1].data.elt[1], encode_scale);
        const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0;
        // Convert to __nv_fp4x4_e2m1
        __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x<kApplyStochasticRounding>(f2_a, f2_b, rbits);
        output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t *>(&out_4x)[0];
        output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t *>(&out_4x)[1];
      }
      // Step 2.7: Store output_c
      if constexpr (kAligned) {
        output_vec.store_to(output_g);
      } else {
        if (r_g < num_rows) {
          output_vec.store_to_elts(output_g, 0, num_ele);
        }
      }
      // Step 2.8: Update output address, row index of shared memory (and row index of global
      // memory for not aligned case)
      output_g += stride_g / kNFP4PerContainer;
      r_s += r_stride;
      if constexpr (!kAligned) {
        r_g += r_stride;
      }
    }
  }
  // Step 3: Transpose, cast and store to output_t
  if constexpr (kReturnTranspose) {
    constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore;  // Stride in columns of shared memory
    constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
    const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut;  // Row in shared memory
    int c_s = threadIdx.x / kNumThreadsStore;                     // Column in shared memory
    size_t r_g = block_idx_x * kTileDim + c_s * kNVecSMem;        // Row in global memory
    const size_t c_g = block_idx_y * kTileDim + r_s;              // Column in global memory
    const size_t stride_g = static_cast<size_t>(c_stride) * kNVecSMem * num_rows;  // Stride in global memory
    const size_t num_ele =
        (c_g < num_rows ? min(static_cast<size_t>(kNVecOut / kNFP4PerContainer),
                              (num_rows - c_g) / kNFP4PerContainer)
                        : 0);  // For not aligned case
    OType *output_g = &output_t[(r_g * num_rows + c_g) / kNFP4PerContainer];  // Output address in global memory
    // Each group of kNumThreadsStore threads in a warp processes one row; we need to find the
    // lane id of the first thread to do the reduction.
    const unsigned src_lane =
        (threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce;
    // This mask represents which threads should do the reduction together.
    const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane;
    const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0;

#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      SMemVec smem_vec[kNVecOut];
      // Step 3.1: Load from shared memory to registers
#pragma unroll
      for (int i = 0; i < kNVecOut; ++i) {
        int r = r_s + i;
        int c = c_s;
        smem_vec[i] = smem[r * kSMemCol + c];
      }
#pragma unroll
      for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
        // Step 3.2: Compute local amax
        CType amax = 0;
        if constexpr (kIs2DBlockScaling) {
          // TODO(zhongbo): 2D block scaling, directly read from amax_smem
          int warp_idx = threadIdx.x / kThreadsPerWarp;  // 0 ~ 7
          constexpr int kNumColsPerWarp = kThreadsPerWarp / kNumThreadsStore * kNVecSMem;  // 8 elements
          constexpr int kNumWarpsPerBlock = kThreadsPerBlock / kThreadsPerWarp;  // 8 warps per block
          constexpr int kNumColsPerIter = kNumColsPerWarp * kNumWarpsPerBlock;
          int tid_in_warp_x = (threadIdx.x / kNumThreadsStore) % kNumColsPerWarp;
          int tid_in_warp_y = (threadIdx.x % kThreadsPerWarp) % kNumThreadsStore;
          int data_col_idx = iter * kNumColsPerIter + warp_idx * kNumColsPerWarp + tid_in_warp_x;
          amax = amax_smem[tid_in_warp_y][data_col_idx / kFP4BlockScalingSize];
        } else {
#pragma unroll
          for (int i = 0; i < kNVecOut; ++i) {
            amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx]));
          }
        }
        // Step 3.3: Reduce amax
        if constexpr (kIsE8Scaling) {
#pragma unroll
          for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) {
            const float other_amax = __shfl_down_sync(mask, amax, delta);
            __builtin_assume(amax >= 0);
            __builtin_assume(other_amax >= 0);
            amax = fmaxf(amax, other_amax);
          }
          amax = __shfl_sync(mask, amax, src_lane);
        }
        // Step 3.4: Compute scale
        ScaleType scale_inv = ComputeDecodeScaleFP4<ScaleType>(amax, global_encode_scale);
        float encode_scale = ComputeEncodeScaleFP4<ScaleType>(scale_inv, global_decode_scale);
        // Step 3.5: Write scale_inv_t
        bool write_scale_inv = is_src_lane;
        if constexpr (!kAligned) {
          write_scale_inv &= (r_g + smem_idx < row_length);
          write_scale_inv &= (c_g < num_rows);
        }
        if (write_scale_inv) {
          size_t row_idx = block_idx_x * kTileDim + c_s * kNVecSMem + smem_idx;
          size_t col_idx = (block_idx_y * (kNumThreadsStore / kNumThreadsReduce) +
                            (threadIdx.x % kNumThreadsStore) / kNumThreadsReduce);
          if constexpr (kSwizzledScale) {
            size_t offset = scale_factor_swizzled_offset<ScaleType>(
                row_idx, col_idx, DIVUP(num_rows, kScaleBlockDim));
            tile_scales_inv_t[offset] = scale_inv;
          } else {
            tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
          }
        }
        // Step 3.6: Quantize
        OVec output_vec;
#pragma unroll
        for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 2) {
          // Pack two elements into each float2
          float2 f2_a;
          float2 f2_b;
          f2_a.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * i].data.elt[smem_idx], encode_scale);
          f2_a.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * i + 1].data.elt[smem_idx], encode_scale);
          f2_b.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * (i + 1)].data.elt[smem_idx], encode_scale);
          f2_b.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * (i + 1) + 1].data.elt[smem_idx], encode_scale);
          const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0;
          // Convert to __nv_fp4x4_e2m1
          __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x<kApplyStochasticRounding>(f2_a, f2_b, rbits);
          output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t *>(&out_4x)[0];
          output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t *>(&out_4x)[1];
        }
        // Step 3.7: Store output_t
        if constexpr (kAligned) {
          output_vec.store_to(output_g + smem_idx * num_rows / kNFP4PerContainer);
        } else {
          if (r_g + smem_idx < row_length) {
            output_vec.store_to_elts(output_g + smem_idx * num_rows / kNFP4PerContainer, 0, num_ele);
          }
        }
      }
      // Step 3.8: Update output address, column index of shared memory (and row index of global
      // memory for not aligned case)
      output_g += stride_g / kNFP4PerContainer;
      c_s += c_stride;
      if constexpr (!kAligned) {
        r_g += c_stride * kNVecSMem;
      }
    }
  }
}
}  // namespace

}  // namespace quantize_transpose_nvfp4

#endif  // CUDA_VERSION >= 12080
namespace detail {

void quantize_transpose_vector_blockwise_fp4(
    const SimpleTensor &input, const SimpleTensor &global_amax, SimpleTensor &scale_inv,
    SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon,
    const bool return_identity, const bool return_transpose, const bool pow2_scale,
    const bool swizzled_scale, const bool use_stochastic_rounding,
    const NVTETensor rng_state_tensor, const bool use_2d_quantization,
    const SimpleTensor &noop_tensor, cudaStream_t stream) {
  NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4);
#if CUDA_VERSION >= 12080
  // pow2_scale is for MXFP4, since it's using E8M0 scaling.
  // Raise an error if pow2_scale is true.
  NVTE_CHECK(!pow2_scale, "No support for pow2_scale for MXFP4 for now");

  if (!return_identity && !return_transpose) {
    return;
  }
  if (use_2d_quantization && !return_identity) {
    return;
  }

  const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
  size_t num_elements = row_length;
  size_t num_rows = 1;
  for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
    num_rows *= input.shape.at(i);
    num_elements *= input.shape.at(i);
  }

  // Early return if the input tensor is empty
  if (num_elements == 0) {
    return;
  }

  size_t scale_stride_x = 0;
  size_t scale_stride_y = 0;
  if (return_identity) {
    scale_stride_x = 1;
    scale_stride_y = scale_inv.shape[1];
  }
  size_t scale_t_stride_x = 0;
  size_t scale_t_stride_y = 0;
  if (return_transpose) {
    scale_t_stride_x = 1;
    scale_t_stride_y = scale_inv_t.shape[1];
  }

  using namespace transformer_engine::quantize_transpose_nvfp4;
  const size_t num_blocks_x = DIVUP(row_length, static_cast<size_t>(kTileDim));
  const size_t num_blocks_y = DIVUP(num_rows, static_cast<size_t>(kTileDim));

  // Noop tensor for CUDA graphs
  const float *noop_ptr = reinterpret_cast<const float *>(noop_tensor.dptr);

  const size_t *rng_state = nullptr;
  if (rng_state_tensor != nullptr) {
    Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
    NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
               "RNG state should contain 2 64-bit values.");
    NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
               "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
    rng_state = reinterpret_cast<const size_t *>(rng_state_te_tensor.data.dptr);
  }

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.dtype, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(
          output.dtype, 2, OutputType,

          dim3 grid(num_blocks_x, num_blocks_y, 1);
          using ScaleType = fp8e4m3;
          constexpr int kScaleBlockDim = 16;
          constexpr bool kPow2Scale = false;
          const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;

          TRANSFORMER_ENGINE_SWITCH_CONDITION(
              return_identity, kReturnIdentity,
              TRANSFORMER_ENGINE_SWITCH_CONDITION(
                  return_transpose, kReturnTranspose,
                  TRANSFORMER_ENGINE_SWITCH_CONDITION(
                      full_tile, kAligned,
                      TRANSFORMER_ENGINE_SWITCH_CONDITION(
                          swizzled_scale, kSwizzledScale,
                          TRANSFORMER_ENGINE_SWITCH_CONDITION(
                              use_stochastic_rounding, kApplyStochasticRounding,
                              TRANSFORMER_ENGINE_SWITCH_CONDITION(
                                  use_2d_quantization, kIs2DBlockScaling,

                                  size_t smem_bytes = kSMemSize * sizeof(InputType);
                                  auto kernel = block_scaled_1d_cast_transpose_kernel<
                                      kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned,
                                      float, InputType, OutputType, ScaleType, kSwizzledScale,
                                      kApplyStochasticRounding, kIs2DBlockScaling>;
                                  if (smem_bytes >= 48 * 1024) {
                                    cudaError_t err = cudaFuncSetAttribute(
                                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                        smem_bytes);
                                    NVTE_CHECK(err == cudaSuccess,
                                               "Failed to set dynamic shared memory size.");
                                  }
                                  kernel<<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
                                      reinterpret_cast<const InputType *>(input.dptr),
                                      reinterpret_cast<const float *>(global_amax.dptr),
                                      reinterpret_cast<OutputType *>(output.dptr),
                                      reinterpret_cast<OutputType *>(output_t.dptr),
                                      reinterpret_cast<ScaleType *>(scale_inv.dptr),
                                      reinterpret_cast<ScaleType *>(scale_inv_t.dptr),
                                      row_length, num_rows, scale_stride_x, scale_stride_y,
                                      scale_t_stride_x, scale_t_stride_y, kScaleBlockDim, epsilon,
                                      rng_state, noop_ptr);)  // kIs2DBlockScaling
                              )                               // kApplyStochasticRounding
                          )                                   // kSwizzledScale
                      )                                       // kAligned
                  )                                           // kReturnTranspose
              )                                               // kReturnIdentity
          )                                                   // OutputType
      )                                                       // InputType

  NVTE_CHECK_CUDA(cudaGetLastError());
#else
  NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif  // CUDA_VERSION >= 12080
}

}  // namespace detail

}  // namespace transformer_engine
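[Editor's note] For readers who want the NVFP4 numerics without the CUDA machinery, here is a hedged host-side restatement of the two-level scaling the kernel above implements. It assumes the E4M3 maximum of 448 and E2M1 maximum of 6, ignores stochastic rounding and the transpose path, and keeps the decode scale in FP32 (the device code first rounds it to FP8-E4M3). The names are ours, not the library's.

```cpp
// A minimal reference sketch of two-level NVFP4 scaling; not part of the commit.
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

constexpr float kFp8Max = 448.0f;  // FP8-E4M3 max
constexpr float kFp4Max = 6.0f;    // FP4-E2M1 max

// Quantize one 16-element block: returns the decode scale the kernel would
// store as FP8-E4M3, and writes the scaled values that would then be rounded
// to FP4-E2M1.
float quantize_block_reference(const std::vector<float> &block, float global_amax,
                               std::vector<float> *scaled) {
  // Global (per-tensor) encode scale, as in ComputeGlobalEncodeScaleFP4.
  float global_encode = (global_amax == 0.0f) ? 1.0f : kFp8Max * kFp4Max / global_amax;
  global_encode = std::min(global_encode, std::numeric_limits<float>::max());
  // Per-block amax.
  float amax = 0.0f;
  for (float v : block) amax = std::max(amax, std::fabs(v));
  // Decode scale (kept in FP32 here; the kernel casts it to FP8-E4M3 first).
  float decode_scale =
      std::min(amax / kFp4Max * global_encode, std::numeric_limits<float>::max());
  // Encode scale; equals kFp4Max / amax when no clamping or FP8 rounding occurs.
  float encode_scale = std::min(1.0f / (decode_scale * (1.0f / global_encode)),
                                std::numeric_limits<float>::max());
  scaled->clear();
  for (float v : block) scaled->push_back(v * encode_scale);  // then rounded to FP4
  return decode_scale;
}
```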
transformer_engine/common/util/cast_gated_kernels.cuh
...
...
@@ -58,7 +58,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const
__grid_constant__
CUtensorMap
tensor_map_output_act
,
const
__grid_constant__
CUtensorMap
tensor_map_output_gate
,
float
*
const
amax_ptr
,
float
*
const
scale_inv_ptr
,
const
float
*
const
scale_ptr
,
const
size_t
rows
,
const
size_t
cols
)
{
const
float
*
const
scale_ptr
,
const
size_t
rows
,
const
size_t
cols
,
const
ParamOP
p
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const
size_t
chunk_offset_Y
=
blockIdx
.
y
*
CHUNK_DIM_Y
;
...
...
@@ -164,7 +165,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType
*
in_gate_sh_curr
=
in_gate_sh
+
buff
*
buff_elems
;
OType
*
out_act_sh_curr
=
out_act_sh
+
buff
*
buff_elems
;
OType
*
out_gate_sh_curr
=
out_gate_sh
+
buff
*
buff_elems
;
#pragma unroll
for
(
int
stage
=
0
;
stage
<
BUFFER_STAGES_NUM
;
++
stage
)
{
const
size_t
stage_offset_Y
=
stage
*
THREADS_PER_CHUNK_Y
;
...
...
@@ -174,6 +174,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float
act_elt
=
static_cast
<
float
>
(
in_act_sh_curr
[
shmem_idx
]);
float
gate_elt
=
static_cast
<
float
>
(
in_gate_sh_curr
[
shmem_idx
]);
bool
dgate_elt
=
true
;
// gating is ideally an identity function
if
constexpr
(
std
::
is_same
<
ParamOP
,
ClampedSwiGLUParam
>::
value
)
{
// In case of GPT OSS, clamp the activation and gate values
dgate_elt
=
gate_elt
<=
p
.
limit
&&
gate_elt
>=
-
p
.
limit
;
// Derivative of clamp
gate_elt
=
min
(
max
(
-
p
.
limit
,
gate_elt
),
p
.
limit
)
+
1
;
}
if
constexpr
(
IS_DGATED
)
{
float
grad_elt
=
static_cast
<
float
>
(
in_grad_sh_curr
[
shmem_idx
]);
...
...
@@ -181,18 +187,27 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const
float
x
=
act_elt
;
float
act_x
;
float
dact_x
;
if
constexpr
(
std
::
is_same
<
ParamOP
,
ClampedSwiGLUParam
>::
value
)
{
const
float
x
=
min
(
act_elt
,
p
.
limit
);
const
float
s
=
sigmoidf
(
p
.
alpha
*
x
);
act_x
=
x
*
s
;
if
(
act_elt
<=
p
.
limit
)
{
dact_x
=
s
+
s
*
(
1
-
s
)
*
p
.
alpha
*
x
;
}
else
{
dact_x
=
0.0
f
;
}
}
else
{
if
constexpr
((
ActOP
==
&
silu
<
fp32
,
fp32
>
)
&&
(
DActOP
==
&
dsilu
<
fp32
,
fp32
>
))
{
const
float
s
=
sigmoidf
(
x
);
act_x
=
x
*
s
;
dact_x
=
x
*
s
*
(
1
-
s
)
+
s
;
}
else
{
act_x
=
ActOP
(
x
,
{});
dact_x
=
DActOP
(
x
,
{});
act_x
=
ActOP
(
x
,
p
);
dact_x
=
DActOP
(
x
,
p
);
}
}
float
after_dact
=
dact_x
*
grad_elt
*
gate_elt
;
float
after_dgate
=
act_x
*
grad_elt
;
float
after_dgate
=
dgate_elt
?
act_x
*
grad_elt
:
0.0
f
;
out_act_sh_curr
[
shmem_idx
]
=
static_cast
<
OType
>
(
scale
*
after_dact
);
out_gate_sh_curr
[
shmem_idx
]
=
static_cast
<
OType
>
(
scale
*
after_dgate
);
...
...
@@ -200,7 +215,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
amax
=
fmaxf
(
amax
,
fabsf
(
after_dact
));
amax
=
fmaxf
(
amax
,
fabsf
(
after_dgate
));
}
else
{
const
float
after_act
=
ActOP
(
act_elt
,
{}
)
*
gate_elt
;
const
float
after_act
=
ActOP
(
act_elt
,
p
)
*
gate_elt
;
out_act_sh_curr
[
shmem_idx
]
=
static_cast
<
OType
>
(
scale
*
after_act
);
amax
=
fmaxf
(
amax
,
fabsf
(
after_act
));
}
...
...
@@ -305,7 +320,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const
__grid_constant__
CUtensorMap
tensor_map_output_gate_colwise
,
e8m0_t
*
const
scales_rowwise
,
e8m0_t
*
const
scales_colwise
,
const
size_t
rows
,
const
size_t
cols
,
const
size_t
scale_stride_rowwise
,
const
size_t
scale_stride_colwise
)
{
const
size_t
scale_stride_colwise
,
const
ParamOP
p
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using
IType2
=
typename
ptx
::
FPx2
<
IType
>
;
using
OType2
=
typename
ptx
::
FPx2
<
OType
>
;
...
...
@@ -481,25 +496,37 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float
gate_elt
=
static_cast
<
float
>
(
in_gate_sh
[
shmem_offset_colwise
]);
float
after_act_elt
;
float
after_gate_elt
;
bool
dgate_elt
=
true
;
// gating is ideally an identity function
if
constexpr
(
std
::
is_same
<
ParamOP
,
ClampedSwiGLUParam
>::
value
)
{
// In case of GPT OSS, clamp the activation and gate values
dgate_elt
=
gate_elt
<=
p
.
limit
&&
gate_elt
>=
-
p
.
limit
;
// Derivative of clamp
gate_elt
=
min
(
max
(
-
p
.
limit
,
gate_elt
),
p
.
limit
)
+
1.0
f
;
}
if
constexpr
(
IS_DGATED
)
{
float
grad_elt
=
static_cast
<
float
>
(
in_grad_sh
[
shmem_offset_colwise
]);
const
float
x
=
act_elt
;
float
act_x
;
float
dact_x
;
if
constexpr
(
std
::
is_same
<
ParamOP
,
ClampedSwiGLUParam
>::
value
)
{
const
float
x
=
min
(
act_elt
,
p
.
limit
);
const
float
s
=
sigmoidf
(
p
.
alpha
*
x
);
act_x
=
x
*
s
;
dact_x
=
act_elt
<=
p
.
limit
?
s
+
s
*
(
1
-
s
)
*
p
.
alpha
*
x
:
0.0
f
;
}
else
{
if
constexpr
((
ActOP
==
&
silu
<
fp32
,
fp32
>
)
&&
(
DActOP
==
&
dsilu
<
fp32
,
fp32
>
))
{
const
float
s
=
sigmoidf
(
x
);
act_x
=
x
*
s
;
dact_x
=
x
*
s
*
(
1
-
s
)
+
s
;
}
else
{
act_x
=
ActOP
(
x
,
{});
dact_x
=
DActOP
(
x
,
{});
act_x
=
ActOP
(
x
,
p
);
dact_x
=
DActOP
(
x
,
p
);
}
}
after_act_elt
=
dact_x
*
grad_elt
*
gate_elt
;
after_gate_elt
=
act_x
*
grad_elt
;
after_gate_elt
=
dgate_elt
?
act_x
*
grad_elt
:
0.0
f
;
}
else
{
after_act_elt
=
ActOP
(
act_elt
,
{}
)
*
gate_elt
;
after_act_elt
=
ActOP
(
act_elt
,
p
)
*
gate_elt
;
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if
constexpr
(
!
std
::
is_same_v
<
IType
,
float
>
)
{
...
...
@@ -603,6 +630,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if
constexpr
(
IS_DGATED
)
{
const
e8m0_t
biased_exponent_gate
=
ptx
::
float_to_e8m0
(
thread_amax_gate
*
Quantized_Limits
<
OType
>::
max_norm_rcp
);
// const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2;
const
size_t
scale_idx_gate
=
scale_idx
+
gate_scale_idx_offset_colwise
;
if
(
tid_Y_colwise
==
0
&&
(
!
out_of_bounds_colwise
))
{
...
...
@@ -724,27 +752,39 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float
gate_elt
=
static_cast
<
float
>
(
in_gate
.
data
.
elt
[
e
]);
float
after_act_elt
;
float
after_gate_elt
;
bool
dgate_elt
=
true
;
if
constexpr
(
std
::
is_same
<
ParamOP
,
ClampedSwiGLUParam
>::
value
)
{
// In case of GPT OSS, clamp the activation and gate values
dgate_elt
=
gate_elt
<=
p
.
limit
&&
gate_elt
>=
-
p
.
limit
;
// Derivative of clamp
gate_elt
=
min
(
max
(
-
p
.
limit
,
gate_elt
),
p
.
limit
)
+
1.0
f
;
}
if
constexpr
(
IS_DGATED
)
{
float
grad_elt
=
static_cast
<
float
>
(
in_grad
.
data
.
elt
[
e
]);
const
float
x
=
act_elt
;
float
act_x
;
float
dact_x
;
if
constexpr
(
std
::
is_same
<
ParamOP
,
ClampedSwiGLUParam
>::
value
)
{
const
float
x
=
min
(
act_elt
,
p
.
limit
);
const
float
s
=
sigmoidf
(
p
.
alpha
*
x
);
act_x
=
x
*
s
;
dact_x
=
act_elt
<=
p
.
limit
?
s
+
s
*
(
1
-
s
)
*
p
.
alpha
*
x
:
0.0
f
;
}
else
{
if
constexpr
((
ActOP
==
&
silu
<
fp32
,
fp32
>
)
&&
(
DActOP
==
&
dsilu
<
fp32
,
fp32
>
))
{
const
float
s
=
sigmoidf
(
x
);
act_x
=
x
*
s
;
dact_x
=
x
*
s
*
(
1
-
s
)
+
s
;
}
else
{
act_x
=
ActOP
(
x
,
{});
dact_x
=
DActOP
(
x
,
{});
act_x
=
ActOP
(
x
,
p
);
dact_x
=
DActOP
(
x
,
p
);
}
}
after_act_elt
=
dact_x
*
grad_elt
*
gate_elt
;
after_gate_elt
=
act_x
*
grad_elt
;
after_gate_elt
=
dgate_elt
?
act_x
*
grad_elt
:
0.0
f
;
after_act_rowwise
[
j
]
=
after_act_elt
;
after_gate_rowwise
[
j
]
=
after_gate_elt
;
}
else
{
after_act_elt
=
ActOP
(
act_elt
,
{}
)
*
gate_elt
;
after_act_elt
=
ActOP
(
act_elt
,
p
)
*
gate_elt
;
after_act_rowwise
[
j
]
=
after_act_elt
;
}
...
...
@@ -833,6 +873,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx
::
mul_cvt_2x
(
out_gate_pair
,
in_gate
,
block_scale_inverse_2x_gate
);
}
}
const
size_t
swizzled_group_idx
=
((
w
+
bank_group
)
*
PACK_SIZE
)
%
SCALE_DIM_X
;
const
size_t
swizzled_idx
=
swizzled_group_idx
+
thread_offset_X_rowwise
;
const
size_t
shmem_offset_rowwise
=
shmem_offset_base_rowwise
+
swizzled_idx
;
...
...
@@ -889,7 +930,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
template
<
bool
IS_DGATED
,
typename
ParamOP
,
float
(
*
ActOP
)(
float
,
const
ParamOP
&
),
float
(
*
DActOP
)(
float
,
const
ParamOP
&
)>
void
cast_fp8_gated
(
const
Tensor
&
grad
,
const
Tensor
&
gated_input
,
Tensor
*
output
,
void
cast_fp8_gated
(
const
Tensor
&
grad
,
const
Tensor
&
gated_input
,
Tensor
*
output
,
ParamOP
p
,
cudaStream_t
stream
)
{
#ifdef __HIP_PLATFORM_AMD__
assert
(
false
);
...
...
@@ -956,6 +997,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
const
size_t
in_gate_mem
=
buff_size_aligned_in
;
const
size_t
out_act_mem
=
buff_size_aligned_out
;
const
size_t
out_gate_mem
=
buff_size_aligned_out
;
const
size_t
shmem_size
=
grad_mem
+
(
in_act_mem
+
in_gate_mem
)
+
(
out_act_mem
+
out_gate_mem
)
+
TMA_SHMEM_ALIGNMENT
;
...
...
@@ -966,8 +1008,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
cast_fp8_gated_kernel
<
IS_DGATED
,
ParamOP
,
ActOP
,
DActOP
,
IType
,
OType
>
<<<
grid_dim
,
block_dim
,
shmem_size
,
stream
>>>
(
tensor_map_grad
,
tensor_map_input_act
,
tensor_map_input_gate
,
tensor_map_output_act
,
tensor_map_output_gate
,
amax_ptr
,
scale_inv_ptr
,
scale_ptr
,
rows
,
cols
);
tensor_map_output_gate
,
amax_ptr
,
scale_inv_ptr
,
scale_ptr
,
rows
,
cols
,
p
);
NVTE_CHECK_CUDA
(
cudaGetLastError
()););
// NOLINT(*)
);
// NOLINT(*)
#endif
...
...
@@ -975,7 +1016,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
template
<
bool
IS_DGATED
,
typename
ParamOP
,
float
(
*
ActOP
)(
float
,
const
ParamOP
&
),
float
(
*
DActOP
)(
float
,
const
ParamOP
&
)>
void
cast_mxfp8_gated
(
const
Tensor
&
grad
,
const
Tensor
&
gated_input
,
Tensor
*
output
,
void
cast_mxfp8_gated
(
const
Tensor
&
grad
,
const
Tensor
&
gated_input
,
Tensor
*
output
,
ParamOP
p
,
cudaStream_t
stream
)
{
#ifdef __HIP_PLATFORM_AMD__
assert
(
false
);
...
...
@@ -1109,7 +1150,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_rowwise
,
tensor_map_output_gate_rowwise
,
tensor_map_output_act_colwise
,
tensor_map_output_gate_colwise
,
scales_rowwise_ptr
,
scales_colwise_ptr
,
rows
,
cols
,
scale_stride_rowwise
,
scale_stride_colwise
);
scale_stride_colwise
,
p
);
NVTE_CHECK_CUDA
(
cudaGetLastError
());
break
;
case
ScalingType
::
COLWISE
:
...
...
@@ -1126,7 +1167,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_rowwise
,
tensor_map_output_gate_rowwise
,
tensor_map_output_act_colwise
,
tensor_map_output_gate_colwise
,
scales_rowwise_ptr
,
scales_colwise_ptr
,
rows
,
cols
,
scale_stride_rowwise
,
scale_stride_colwise
);
scale_stride_colwise
,
p
);
NVTE_CHECK_CUDA
(
cudaGetLastError
());
break
;
case
ScalingType
::
BIDIMENSIONAL
:
...
...
@@ -1135,7 +1176,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
OType
,
true
,
true
,
THREADS_PER_CHUNK_NON_COLWISE
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shmem_size
));
mxfp8_kernel
::
cast_mxfp8_gated_kernel
<
IS_DGATED
,
ParamOP
,
ActOP
,
DActOP
,
IType
,
OType
,
true
,
true
,
THREADS_PER_CHUNK_NON_COLWISE
>
<<<
grid
,
block_size
,
shmem_size
,
stream
>>>
(
...
...
@@ -1143,7 +1183,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_rowwise
,
tensor_map_output_gate_rowwise
,
tensor_map_output_act_colwise
,
tensor_map_output_gate_colwise
,
scales_rowwise_ptr
,
scales_colwise_ptr
,
rows
,
cols
,
scale_stride_rowwise
,
scale_stride_colwise
);
scale_stride_colwise
,
p
);
NVTE_CHECK_CUDA
(
cudaGetLastError
());
break
;
});
// NOLINT(*)
...
...
@@ -1152,12 +1192,9 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
}
template
<
typename
ParamOP
,
float
(
*
ActOP
)(
float
,
const
ParamOP
&
)>
void
cast_gated
(
const
Tensor
&
input
,
Tensor
*
output
,
cudaStream_t
stream
)
{
void
cast_gated
(
const
Tensor
&
input
,
Tensor
*
output
,
ParamOP
p
,
cudaStream_t
stream
)
{
CheckInputTensor
(
input
,
"gated_act_input"
);
CheckOutputTensor
(
*
output
,
"gated_act_output"
);
NVTE_CHECK
(
output
->
flat_first_dim
()
==
input
.
flat_first_dim
(),
"Wrong output shape. Expected (after flattening) ["
,
input
.
flat_first_dim
(),
", *], got ["
,
output
->
flat_first_dim
(),
", "
,
output
->
flat_last_dim
(),
"]."
);
NVTE_CHECK
(
input
.
flat_last_dim
()
%
2
==
0
,
"Wrong input shape. Expected (after flattening) last dimension to be even, "
,
"got ["
,
input
.
flat_first_dim
(),
", "
,
input
.
flat_last_dim
(),
"]."
);
...
...
@@ -1179,7 +1216,7 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast
<
const
fp32
*>
(
output
->
scale
.
dptr
),
reinterpret_cast
<
fp32
*>
(
output
->
amax
.
dptr
),
reinterpret_cast
<
fp32
*>
(
output
->
scale_inv
.
dptr
),
input
.
flat_first_dim
(),
output
->
flat_last_dim
(),
{}
,
stream
);
output
->
flat_last_dim
(),
p
,
stream
);
}
else
{
NVTE_ERROR
(
"Not implemented scaling mode: "
+
to_string
(
output
->
scaling_mode
)
+
"."
);
});
// NOLINT(*)
...
...
@@ -1188,7 +1225,8 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
template
<
typename
ParamOP
,
float
(
*
ActOP
)(
float
,
const
ParamOP
&
),
float
(
*
DActOP
)(
float
,
const
ParamOP
&
)>
void
cast_dgated
(
const
Tensor
&
grad
,
const
Tensor
&
input
,
Tensor
*
output
,
cudaStream_t
stream
)
{
void
cast_dgated
(
const
Tensor
&
grad
,
const
Tensor
&
input
,
Tensor
*
output
,
ParamOP
p
,
cudaStream_t
stream
)
{
CheckInputTensor
(
grad
,
"dgated_act_grad"
);
CheckInputTensor
(
input
,
"dgated_act_input"
);
CheckOutputTensor
(
*
output
,
"dgated_act_output"
);
...
...
@@ -1217,7 +1255,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt
reinterpret_cast
<
const
fp32
*>
(
output
->
scale
.
dptr
),
reinterpret_cast
<
fp32
*>
(
output
->
amax
.
dptr
),
reinterpret_cast
<
fp32
*>
(
output
->
scale_inv
.
dptr
),
grad
.
flat_first_dim
(),
grad
.
flat_last_dim
(),
{}
,
stream
);
grad
.
flat_last_dim
(),
p
,
stream
);
}
else
{
NVTE_ERROR
(
"Not implemented scaling mode: "
+
to_string
(
output
->
scaling_mode
)
+
"."
);
});
// NOLINT(*)
...
...
@@ -1226,7 +1264,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt
template
<
bool
IS_DGATED
,
typename
ParamOP
,
float
(
*
ActOP
)(
float
,
const
ParamOP
&
),
float
(
*
DActOP
)(
float
,
const
ParamOP
&
)>
void
quantize_gated
(
const
Tensor
&
grad
,
const
Tensor
&
gated_input
,
Tensor
*
output
,
void
quantize_gated
(
const
Tensor
&
grad
,
const
Tensor
&
gated_input
,
Tensor
*
output
,
ParamOP
p
,
cudaStream_t
stream
)
{
constexpr
bool
allow_empty
=
false
;
CheckInputTensor
(
gated_input
,
"gated_input"
);
...
...
@@ -1266,17 +1304,17 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
if
(
is_delayed_tensor_scaling
(
output
->
scaling_mode
))
{
if
(
use_tma_kernels
)
{
cast_fp8_gated
<
IS_DGATED
,
ParamOP
,
ActOP
,
DActOP
>
(
grad
,
gated_input
,
output
,
stream
);
cast_fp8_gated
<
IS_DGATED
,
ParamOP
,
ActOP
,
DActOP
>
(
grad
,
gated_input
,
output
,
p
,
stream
);
}
else
{
if
constexpr
(
IS_DGATED
)
{
cast_dgated
<
ParamOP
,
ActOP
,
DActOP
>
(
grad
,
gated_input
,
output
,
stream
);
cast_dgated
<
ParamOP
,
ActOP
,
DActOP
>
(
grad
,
gated_input
,
output
,
p
,
stream
);
}
else
{
cast_gated
<
ParamOP
,
ActOP
>
(
gated_input
,
output
,
stream
);
cast_gated
<
ParamOP
,
ActOP
>
(
gated_input
,
output
,
p
,
stream
);
}
}
}
else
if
(
is_mxfp_scaling
(
output
->
scaling_mode
))
{
}
else
if
(
is_mxfp
8
_scaling
(
output
->
scaling_mode
))
{
if
(
use_tma_kernels
)
{
cast_mxfp8_gated
<
IS_DGATED
,
ParamOP
,
ActOP
,
DActOP
>
(
grad
,
gated_input
,
output
,
stream
);
cast_mxfp8_gated
<
IS_DGATED
,
ParamOP
,
ActOP
,
DActOP
>
(
grad
,
gated_input
,
output
,
p
,
stream
);
}
else
{
NVTE_ERROR
(
"Invalid input shape. Expected the last dimension to be divisible "
,
"by 32, got input of shape "
,
gated_input
.
data
.
shape
);
...
...
@@ -1292,7 +1330,7 @@ namespace detail {
 template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
           float (*DActOP)(float, const ParamOP &)>
 void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output,
-                           cudaStream_t stream) {
+                           ParamOP p, cudaStream_t stream) {
   using namespace gated_kernels;
   Tensor grad_empty_tensor;
   const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
@@ -1301,13 +1339,14 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input,
   if (is_supported_by_CC_100()) {
     quantize_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor,
-                                                      output_tensor, stream);
+                                                      output_tensor, p, stream);
   } else {
     if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) {
       if constexpr (IS_DGATED) {
-        cast_dgated<ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor, output_tensor, stream);
+        cast_dgated<ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor, output_tensor, p,
+                                            stream);
       } else {
-        cast_gated<ParamOP, ActOP>(gated_input_tensor, output_tensor, stream);
+        cast_gated<ParamOP, ActOP>(gated_input_tensor, output_tensor, p, stream);
       }
     } else {
       // MX scaling
transformer_engine/common/util/cast_kernels.cuh
@@ -25,6 +25,7 @@
 #include "../util/vectorized_pointwise.h"
 #include "../utils.cuh"
 #include "math.h"
+#include "nvfp4_transpose.cuh"
 #include "ptx.cuh"
 #include "transformer_engine/transformer_engine.h"
@@ -110,6 +111,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
   const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
   const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
   const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
   // helps resolving bank conflicts in shmem
   const int thread_lane = threadIdx.x % THREADS_PER_WARP;
   const int bank_group = thread_lane / THREADS_PER_BANK;
@@ -137,8 +140,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
   // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
   IType *in_sh = reinterpret_cast<IType *>(dshmem);
   IType *act_in_sh = reinterpret_cast<IType *>(dshmem + elt_input_mem);
-  OType *out_rowwise_sh = reinterpret_cast<OType *>(dshmem + in_mem);
-  OType *out_colwise_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
+  OType *out_rowwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem);
+  OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
   IType *cached_act_sh = in_sh;  // in_sh is used as a cache buffer
   constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
@@ -286,7 +290,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
         const float scaled_out = in * block_scale_inverse;
         const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
-        out_colwise_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
+        out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
       }
     }
@@ -410,10 +414,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
       // 2. Compute E8M0 scaling factor
       const e8m0_t biased_exponent =
           ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
-      const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
-      const size_t stage_scales_offset_X = scales_offset_X_rowwise;
-      const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
+      const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
+      const int stage_scales_offset_X = scales_offset_X_rowwise;
+      const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
       if (rowwise_scale_is_within_bounds) {
         scales_rowwise[scale_idx] = biased_exponent;
       }
       const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
       const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
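For readers unfamiliar with E8M0 scaling in the hunk above: the per-block scale is a pure power of two, stored as a single biased exponent byte. A host-side sketch of that arithmetic (ours, plain C++; the device truth is ptx::float_to_e8m0 and ptx::exp2f_rcp, and rounding the exponent up is our assumption here so the amax still fits after scaling):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const float block_amax = 3.7f;
  const float fp8_e4m3_max = 448.0f;  // FP8 E4M3 max-normal
  // Pick the smallest power-of-two scale that keeps amax representable.
  const int unbiased = static_cast<int>(std::ceil(std::log2(block_amax / fp8_e4m3_max)));
  const uint8_t biased_exponent = static_cast<uint8_t>(unbiased + 127);  // E8M0 bias of 127
  const float block_scale_inverse = std::ldexp(1.0f, -unbiased);         // 1 / 2^unbiased
  std::printf("e8m0=%u, scaled amax=%f\n", biased_exponent,
              block_amax * block_scale_inverse);  // stays <= 448 by construction
}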
@@ -441,7 +447,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
         const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
         const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
         const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
-        out.store_to(&out_rowwise_sh[shmem_offset_rowwise]);
+        out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
       }
     }
@@ -456,19 +462,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
     // Initiate TMA transfer to copy shared memory to global memory
     if (is_master_thread) {
-      const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
-      const size_t global_offset_X = block_offset_X;
-      const size_t buff_offset = buff * BUFF_DIM;
+      const int global_offset_Y = block_offset_Y + stage_offset_Y;
+      const int global_offset_X = block_offset_X;
+      const int buff_offset = buff * BUFF_DIM;
       if constexpr (ROWWISE_SCALING) {
         ptx::cp_async_bulk_tensor_2d_shared_to_global(
             reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
-            global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_sh[buff_offset]));
+            global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset]));
       }
       if constexpr (COLWISE_SCALING) {
         ptx::cp_async_bulk_tensor_2d_shared_to_global(
             reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
-            global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_sh[buff_offset]));
+            global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset]));
       }
       // Create a "bulk async-group" out of the previous bulk copy operation.
@@ -489,18 +495,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
     // Added extra 1-element padding per thread_X to reduce bank conflicts
     float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);
-    constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
+    constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
-    const size_t shmem_thread_offset =
+    const int shmem_thread_offset =
         tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1);
 #pragma unroll
     for (int w = 0; w < WAVES; ++w) {
-      const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
-      const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
+      const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
+      const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
 #pragma unroll
       for (int e = 0; e < PACK_SIZE; ++e) {
         const int j = w * PACK_SIZE + e;
-        const size_t shmem_elt_idx = swizzled_group_offset + e;
+        const int shmem_elt_idx = swizzled_group_offset + e;
         partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j];
       }
     }
@@ -508,15 +514,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
 #pragma unroll
     for (int i = 0; i < THREADS_Y; ++i) {
       // Add extra element offset per MXFP8 scaling block [1x32]
-      const size_t scaling_block = threadIdx.x / SCALE_DIM_X;
+      const int scaling_block = threadIdx.x / SCALE_DIM_X;
       thread_partial_dbias +=
           partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block];
     }
   }
-  const size_t dbias_stride = cols;
-  const size_t dbias_offset_Y = blockIdx.y;
-  const size_t dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
-  const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
+  const int dbias_stride = cols;
+  const int dbias_offset_Y = blockIdx.y;
+  const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
+  const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
   const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols);
   if (!col_out_of_bounds_dbias) {
     dbias_workspace[dbias_idx] = thread_partial_dbias;
@@ -539,6 +545,528 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
 #endif  // __HIP_PLATFORM_AMD__
 }  // namespace mxfp8_kernel

namespace nvfp4_kernel {

using namespace ptx;

constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 16;

constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = 32;

constexpr size_t PACK_SIZE = 8;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;

// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4;  // 256

// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X;  // 8 = 128 / 16

// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax,
                                                                   const float S_enc) {
  constexpr float rcp_6f = 1.0f / 6.0f;
  // const float S_dec_b = block_amax * rcp_6f;
  // const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
  // return S_dec_b_fp8;
  return static_cast<fp8e4m3>(block_amax * rcp_6f * S_enc);
}
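The two-level NVFP4 scaling that this helper implements is easy to verify numerically. A plain C++ sketch (ours, not part of the commit; it ignores the slight perturbation from quantizing S_dec_b to FP8 E4M3): S_enc comes from the tensor-level second-stage scale, the per-block decode factor S_dec_b is what gets stored, and the encode multiplier S_enc / S_dec_b maps block_amax onto the FP4 E2M1 maximum of 6.0.

#include <cstdio>

int main() {
  const float fp4_max = 6.0f, fp8_max = 448.0f;
  const float global_amax = 12.5f;
  const float S_enc = fp8_max * fp4_max / global_amax;  // one common choice of encoding scale
  const float block_amax = 3.2f;
  const float S_dec_b = block_amax / fp4_max * S_enc;   // what the kernel stores (as fp8e4m3)
  const float encode_mul = S_enc / S_dec_b;             // block_scale_inverse in the kernel
  // block_amax * encode_mul == fp4_max exactly, before FP8 rounding of S_dec_b
  std::printf("block_amax * encode_mul = %f (targets %.1f)\n", block_amax * encode_mul, fp4_max);
}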
#define DIRECT_SCALING_FACTORS_STORE 1

template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
          typename IType, typename OType, bool COLWISE_SCALING, size_t CHUNK_DIM_Y,
          size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
    cast_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
                      const __grid_constant__ CUtensorMap tensor_map_output_rowwise,
                      const __grid_constant__ CUtensorMap tensor_map_output_colwise,
                      fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0,
                      const float *noop, float *const amax_ptr,
                      const float *const nvfp4_second_stage_scale_ptr, const size_t rows,
                      const size_t cols, const size_t scale_stride_rowwise,
                      const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  constexpr bool ROWWISE_SCALING = true;
  constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
      (!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);

  using IType2 = typename ptx::FPx2<IType>;

  if constexpr (!COMPUTE_ACTIVATIONS) {
    if (noop != nullptr && noop[0] == 1.0f) {
      return;
    }
  }

  constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X;
  constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW;
  constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE;

  static_assert(BUFF_DIM_Y >= SCALE_DIM_Y &&
                "Number of buffer rows must be greater or equal to the size of the columwise "
                "scaling block\0");
  static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y);
  static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE &&
                "Number of buffer rows must be greater or equal to the number of rowwise "
                "processing threads in Y dimension\0");

  constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X;
  constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8;  // Holds 2 elements of 4-bit size
  constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X;
  constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X;
  constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;

  constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE;
  // static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X);  // there should be a sufficient number of
  //                                                   // threads to process one row in a single iteration

  constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;

  const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
  const int block_offset_X = blockIdx.x * CHUNK_DIM_X;
  const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
  const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
  const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
  const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;

  const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
  const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
  const int tid_Y_colwise = 0;
  const int tid_X_colwise = threadIdx.x;

  const int thread_offset_Y_rowwise = tid_Y_rowwise;
  const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
  const int thread_offset_Y_colwise = tid_Y_colwise;
  const int thread_offset_X_colwise = tid_X_colwise;
  // Each thread processes two adjacent elements
  const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
  const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
  const int col_base_colwise = block_offset_X + thread_offset_X_colwise;

  const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);

  const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
  const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
  const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
  const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;

  const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
  const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols;

  // helps resolving bank conflicts in shmem
  const int thread_lane = threadIdx.x % THREADS_PER_WARP;
  const int bank_group = thread_lane / THREADS_PER_BANK;

  constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
  constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
  constexpr size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out_nvfp4 =
      DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out_mxfp8 =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);

  constexpr size_t buff_size_nvfp4_scales =
      CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3);
  constexpr size_t buff_size_mxfp8_scales =
      (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0);

  constexpr size_t in_mem = buff_size_aligned_in;

  constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0);
  constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0);
  constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0);
  constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0);

  extern __shared__ char dynamic_shmem[];
  uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) Does not guarantee the pointer to be aligned!
  uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
                     ~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));

  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
  IType *in_sh = reinterpret_cast<IType *>(dshmem);
  fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
  OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise_data);
  fp8e4m3 *out_rowwise_scales_sh =
      reinterpret_cast<fp8e4m3 *>(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
  e8m0_t *out_colwise_scales_sh = reinterpret_cast<e8m0_t *>(
      dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
  IType *cached_act_sh = in_sh;  // in_sh is used as a cache buffer

  constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;

  const bool is_master_thread = (threadIdx.x == 0);

  // Compute a global encoding/decoding scaling factor for all S_dec_b
  const float S_enc =
      (nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr);

  float thread_amax = 0.0f;

  // Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ alignas(8) uint64_t mbar[STAGES];

  initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, is_master_thread);

  copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
                    &mbar[0], is_master_thread);

#pragma unroll
  for (int stage = 0; stage < STAGES; ++stage) {
    const int buff = stage % BUFFS_NUM;
    const int next_stage = stage + 1;
    const int stage_offset_Y = stage * BUFF_DIM_Y;

    const int buff_offset_in = buff * BUFF_IN_DIM;
    const int buff_offset_out = buff * BUFF_OUT_DIM;

    if (next_stage < STAGES) {
      // Wait for TMA transfer to have finished reading shared memory.
      // I.e. the buffer is ready to be written to
      ptx::cp_async_bulk_wait_group_read<1>();

      const int next_buff = next_stage % BUFFS_NUM;
      const int next_stage_offset_Y = next_stage * BUFF_DIM_Y;
      const int global_offset_Y = block_offset_Y + next_stage_offset_Y;
      const int global_offset_X = block_offset_X;
      const int next_buff_offset = next_buff * BUFF_IN_DIM;
      copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
                        global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
    }

    ptx::fence_proxy_async_shared_cta();

    // Wait for the data to have arrived
    ptx::mbarrier_wait_parity(&mbar[stage], 0);

    float block_amax = 0.0f;
    if constexpr (COLWISE_SCALING) {
      const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise;
      block_amax = 0.0f;
      float in_compute_colwise[SCALE_DIM_Y];
      IType in_colwise_IType[SCALE_DIM_Y];

      // 1. Read/Compute elements. Find MXFP8-block AMAX
      if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
        IType block_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
        for (int i = 0; i < SCALE_DIM_Y; ++i) {
          const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
          in_colwise_IType[i] = in_sh[shmem_offset_colwise];
          block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i]));
        }
        block_amax = static_cast<float>(block_amax_f16);
      } else {
#pragma unroll
        for (int i = 0; i < SCALE_DIM_Y; ++i) {
          const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;

          float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
          if constexpr (COMPUTE_ACTIVATIONS) {
            elt = OP(elt, {});
          }
          // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
          if constexpr (!std::is_same_v<IType, float>) {
            elt = static_cast<float>(static_cast<IType>(elt));
          }
          // Cache computed activations to avoid computing them again in the 2nd pass along another dimension
          if constexpr (IS_CACHED_ACT_OP) {
            cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
          }

          if constexpr (COMPUTE_ACTIVATIONS) {
            const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows);
            const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
            if (!out_of_bounds) {
              block_amax = fmaxf(block_amax, fabsf(elt));
            }
          } else {
            // If no activation, elt is 0 so we can safely do this
            block_amax = fmaxf(block_amax, fabsf(elt));
          }
          in_compute_colwise[i] = elt;
        }
      }

      // 2. Compute E8M0 scaling factor
      const e8m0_t biased_exponent =
          ptx::float_to_e8m0(block_amax * Quantized_Limits<OType>::max_norm_rcp);

      const int global_scales_offset_Y = scales_offset_Y_colwise + stage;
      const int global_scales_offset_X = scales_offset_X_colwise;
      const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
      if (colwise_scale_is_within_bounds) {
        scales_colwise_e8m0[scale_idx] = biased_exponent;
      }
      const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);

      // 3. Scale elements
#pragma unroll
      for (int i = 0; i < SCALE_DIM_Y; ++i) {
        float in;
        if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
          in = static_cast<float>(in_colwise_IType[i]);
        } else {
          in = in_compute_colwise[i];
        }
        const float scaled_out = in * block_scale_inverse;

        const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
        out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
      }
    }

    if constexpr (ROWWISE_SCALING) {
      const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
      for (int it = 0; it < ITERATIONS_ROWWISE; ++it) {
        const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;

        const int shmem_offset_base_rowwise_in =
            buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
        const int shmem_offset_base_rowwise_out =
            buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;

        const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE;

        block_amax = 0.0f;
        float in_compute_rowwise[SCALE_DIM_X];
        Vec<IType, PACK_SIZE> in_cached[WAVES];

        // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
        Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];

        // 1. Read/Compute elements. Find NVFP4-block AMAX
        if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
          IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
            const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
            // Load elements
            in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
            for (int e = 0; e < PACK_SIZE / 2; ++e) {
              ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
            }
          }
          block_amax =
              static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
        } else if constexpr (IS_CACHED_ACT_OP) {
          // ensures that all writes to cache made in the section above are visible to all threads
          __syncthreads();
          IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
            const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;

            const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
            const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
            const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);

            // Load cached elements
            in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
            // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements)
            // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries
            if (!out_of_bounds) {
              if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
                for (int e = 0; e < PACK_SIZE; ++e) {
                  block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e]));
                }
              } else {
#pragma unroll
                for (int e = 0; e < PACK_SIZE; e += 2) {
                  const IType2 in_cached_2x = {in_cached[w].data.elt[e],
                                               in_cached[w].data.elt[e + 1]};
                  ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
                }
              }
            }
          }
          if constexpr (!std::is_same_v<IType, float>) {
            block_amax =
                static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
          }
        } else {
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
            const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;

            Vec<IType, PACK_SIZE> in;
            Vec<IType, PACK_SIZE> act_in;

            in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
            for (int e = 0; e < PACK_SIZE; ++e) {
              const int j = w * PACK_SIZE + e;
              // Compute element
              float elt = static_cast<float>(in.data.elt[e]);
              if constexpr (COMPUTE_ACTIVATIONS) {
                elt = OP(elt, {});
              }
              // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
              if constexpr (!std::is_same_v<IType, float>) {
                elt = static_cast<float>(static_cast<IType>(elt));
              }
              if constexpr (COMPUTE_ACTIVATIONS) {
                const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
                const bool swizzled_col_out_of_bounds =
                    (block_offset_X + swizzled_thread_idx >= cols);
                const bool out_of_bounds =
                    (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
                if (!out_of_bounds) {
                  block_amax = fmaxf(block_amax, fabsf(elt));
                }
              } else {
                // If no activation, elt is 0 so we can safely do this
                block_amax = fmaxf(block_amax, fabsf(elt));
              }
              in_compute_rowwise[j] = elt;
            }
          }
        }

        // 2. Compute E4M3 scaling factor
        const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc);

#if DIRECT_SCALING_FACTORS_STORE
        // Check boundaries
        if (rowwise_scale_is_within_bounds) {
          const int scales_offset_Y =
              scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE;
          const int scales_offset_X = scales_offset_X_rowwise;
          const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X;
          scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8;
        }
#else
        const int shmem_scales_offset_Y =
            stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise;
        const int shmem_scales_offset_X = tid_X_rowwise;
        const int scale_idx =
            shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X;
        out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8;
#endif
        // Compute "correct" per-block encoding scaling factor
        const float block_scale_inverse =
            __fdiv_rn(S_enc, static_cast<float>(S_dec_b_fp8));  // S_enc_b_fp8

        // 3. Scale elements
#pragma unroll
        for (int w = 0; w < WAVES; ++w) {
          Vec<fp4e2m1x4, PACK_SIZE / 4> out;  // Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
          for (int e = 0; e < PACK_SIZE / 4; ++e) {
            IType2 in01;
            IType2 in23;
            if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
              in01 = in_IType[w].data.elt[2 * e];
              in23 = in_IType[w].data.elt[2 * e + 1];
            } else if constexpr (IS_CACHED_ACT_OP) {
              in01.x = in_cached[w].data.elt[4 * e];
              in01.y = in_cached[w].data.elt[4 * e + 1];
              in23.x = in_cached[w].data.elt[4 * e + 2];
              in23.y = in_cached[w].data.elt[4 * e + 3];
            } else {
              const int j = w * PACK_SIZE + 4 * e;
              in01.x = in_compute_rowwise[j];
              in01.y = in_compute_rowwise[j + 1];
              in23.x = in_compute_rowwise[j + 2];
              in23.y = in_compute_rowwise[j + 3];
            }
            fp4e2m1x4 &out_quad = reinterpret_cast<fp4e2m1x4 &>(out.data.elt[e]);
            ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse);
          }
          const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
          const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
          const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
          out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
        }
      }
    }

    __builtin_assume(thread_amax >= 0);
    __builtin_assume(block_amax >= 0);
    thread_amax = fmaxf(thread_amax, block_amax);

    // Wait for shared memory writes to be visible to TMA engine.
    ptx::fence_proxy_async_shared_cta();
    __syncthreads();
    // After syncthreads, writes by all threads are visible to TMA engine.

    // Initiate TMA transfer to copy shared memory to global memory
    if (is_master_thread) {
      const int global_offset_Y = block_offset_Y + stage_offset_Y;
      const int global_offset_X = block_offset_X;
      const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM;
      const int buff_offset_mxfp8 = buff * BUFF_IN_DIM;

      if constexpr (ROWWISE_SCALING) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
            global_offset_Y,
            reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset_nvfp4]));
      }
      if constexpr (COLWISE_SCALING) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
            global_offset_Y,
            reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset_mxfp8]));
      }

      // Create a "bulk async-group" out of the previous bulk copy operation.
      ptx::cp_async_bulk_commit_group();
    }
  }

#if !DIRECT_SCALING_FACTORS_STORE
  // Vectorized store of scaling factors.
  // Each thread stores multiple scaling factors in one store instruction.
  if constexpr (ROWWISE_SCALING) {
    // Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X
    const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x;
    const int scales_offset_X_rowwise = scales_block_offset_X_rowwise;
    const int scale_idx_global =
        scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise;
    const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW;
    if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) &&
        (scales_offset_X_rowwise < (cols / SCALE_DIM_X))) {
      using ScalesVec_t = Vec<fp8e4m3, NVFP4_SCALING_FACTORS_PER_CHUNK_ROW>;
      const ScalesVec_t &scales =
          *reinterpret_cast<ScalesVec_t *>(&out_rowwise_scales_sh[scale_idx_shmem]);
      scales.store_to(&scales_rowwise_e4m3[scale_idx_global]);
    }
  }
#endif

  float chunk_amax = 0.0f;
  if (amax_ptr != nullptr) {
    const int warp_id = threadIdx.x / THREADS_PER_WARP;
    // Reduce the amax over the block
    chunk_amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(thread_amax, warp_id);
  }

  if (is_master_thread && amax_ptr != nullptr) {
    atomicMaxFloat(amax_ptr, chunk_amax);
  }

  destroy_barriers<STAGES>(mbar, is_master_thread);
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}  // namespace nvfp4_kernel

constexpr size_t FP8_CHUNK_DIM_Y = 128;
constexpr size_t FP8_CHUNK_DIM_X = 128;
constexpr size_t FP8_THREADS_PER_CHUNK = 128;
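The `((w + bank_group) * PACK_SIZE) % SCALE_DIM_X` pattern that recurs throughout the kernel above is a cyclic shift of the PACK_SIZE-aligned groups in a row: every thread still touches each group exactly once, but threads in different bank groups start at different shared-memory banks. A trivial standalone C++ check of that property (ours, not part of the commit):

#include <cstdio>

int main() {
  const int SCALE_DIM_X = 16, PACK_SIZE = 8, WAVES = SCALE_DIM_X / PACK_SIZE;
  for (int bank_group = 0; bank_group < 4; ++bank_group) {
    std::printf("bank_group %d:", bank_group);
    for (int w = 0; w < WAVES; ++w) {
      // Offsets printed per wave form a permutation of {0, 8} for every bank group.
      std::printf(" %2d", ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X);
    }
    std::printf("\n");
  }
}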
@@ -903,7 +1431,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows,
 }

 template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
-static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) {
+void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) {
   const size_t N = product(input.data.shape);
   const bool isFullTile = (N % ELEMS_PER_BLOCK == 0);
@@ -1192,6 +1720,141 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
 #endif
 }

// This kernel supports only two scaling cases:
//  1. r16c0  - Rowwise NVFP4
//  2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &)>
void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) {
  using namespace nvfp4_kernel;
  using namespace ptx;
  checkCuDriverContext(stream);

  NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated.");
  NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
  NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
  NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");

  bool use_colwise_scaling = output->has_columnwise_data();
  if (use_colwise_scaling) {
    NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
               "Columnwise scaling tensor must be allocated");
  }
  CheckNoopTensor(*noop, "cast_noop");

  const size_t rows = input.flat_first_dim();
  const size_t cols = input.flat_last_dim();

  constexpr size_t CHUNK_DIM_Y = 128;
  constexpr size_t CHUNK_DIM_X = 128;
  constexpr size_t THREADS_PER_CHUNK = 128;
  constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;

  const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
  const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
  const dim3 grid(blocks_X, blocks_Y);
  const size_t block_size = THREADS_PER_CHUNK;

  const size_t scale_stride_rowwise = output->scale_inv.shape[1];
  const size_t scale_stride_colwise =
      use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1;

  fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast<fp8e4m3 *>(output->scale_inv.dptr);
  e8m0_t *const scales_colwise_e8m0_ptr =
      use_colwise_scaling ? reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr)
                          : nullptr;
  const ScalingType scaling_type =
      use_colwise_scaling ? ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE;

  float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
  const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
  const float *const nvfp4_second_stage_scale_ptr =
      reinterpret_cast<const float *>(output->scale.dptr);

  // Output data type is only required for the column-wise MXFP8 scaling.
  // It has no effect for the row-wise NVFP4 scaling, but is set to the default E4M3 for the macros to work
  const DType output_data_type =
      use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3;

  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      input.dtype(), IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
          output_data_type, OType,

          alignas(64) CUtensorMap tensor_map_input{};
          alignas(64) CUtensorMap tensor_map_output_rowwise{};
          alignas(64) CUtensorMap tensor_map_output_colwise{};

          create_2D_tensor_map(tensor_map_input, input.data, rows, cols, nvfp4_kernel::BUFF_DIM_Y,
                               BUFF_DIM_X, cols, 0, sizeof(IType) * 8);
          create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols,
                               nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, 4);

          if (use_colwise_scaling) {
            create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols,
                                 nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8);
          }

          constexpr size_t buff_elems = nvfp4_kernel::BUFF_DIM_Y * BUFF_DIM_X;
          constexpr size_t buff_elems_total = nvfp4_kernel::BUFFS_NUM * buff_elems;
          constexpr size_t buff_size_aligned_in =
              DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
          constexpr size_t buff_size_aligned_out_nvfp4 =
              DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
          constexpr size_t buff_size_aligned_out_mxfp8 =
              DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);

          constexpr size_t buff_size_nvfp4_scales =
              (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3);
          constexpr size_t buff_size_mxfp8_scales =
              (CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t);

          constexpr size_t in_mem = buff_size_aligned_in;

          const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4;
          const size_t out_colwise_data_mem =
              use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0;

          const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales;
          const size_t out_colwise_scales_mem = use_colwise_scaling ? buff_size_mxfp8_scales : 0;

          const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem +
                                 out_rowwise_scales_mem + out_colwise_scales_mem +
                                 TMA_SHMEM_ALIGNMENT;

          const size_t dshmem_size = in_mem + out_mem;

          switch (scaling_type) {
            case ScalingType::ROWWISE:
              cudaFuncSetAttribute(
                  cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, false,
                                    CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
                  cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);

              cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, false, CHUNK_DIM_Y,
                                CHUNK_DIM_X, THREADS_PER_CHUNK>
                  <<<grid, block_size, dshmem_size, stream>>>(
                      tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
                      scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
                      nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
                      scale_stride_colwise);
              break;
            case ScalingType::BIDIMENSIONAL:
              cudaFuncSetAttribute(
                  cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, true,
                                    CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
                  cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);

              cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, true, CHUNK_DIM_Y,
                                CHUNK_DIM_X, THREADS_PER_CHUNK>
                  <<<grid, block_size, dshmem_size, stream>>>(
                      tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
                      scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
                      nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
                      scale_stride_colwise);
              break;
          });  // NOLINT(*)
  );           // NOLINT(*)
}

namespace detail {

using Empty = transformer_engine::Empty;
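The shared-memory budget the launcher requests is small enough to check by hand. A back-of-envelope C++ sketch (ours; it assumes TMA_SHMEM_ALIGNMENT is 128 bytes and a BF16 input) of `dshmem_size` for the 128x128 chunk configuration above:

#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t ALIGN = 128;  // assumed value of TMA_SHMEM_ALIGNMENT
  auto divup_to_multiple = [&](std::size_t x) { return (x + ALIGN - 1) / ALIGN * ALIGN; };
  const std::size_t buff_elems_total = 2 * 32 * 128;  // BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X
  const std::size_t in_mem = divup_to_multiple(buff_elems_total * 2);         // BF16 input
  const std::size_t out_nvfp4 = divup_to_multiple(buff_elems_total * 4 / 8);  // 4-bit rowwise data
  const std::size_t out_mxfp8 = divup_to_multiple(buff_elems_total * 1);      // 8-bit colwise data
  const std::size_t nvfp4_scales = 128 * 128 / 16;  // one fp8e4m3 per 1x16 block
  const std::size_t mxfp8_scales = 128 * 128 / 32;  // one e8m0 per 32x1 block
  const std::size_t total =
      in_mem + out_nvfp4 + out_mxfp8 + nvfp4_scales + mxfp8_scales + ALIGN;
  std::printf("dshmem ~ %zu bytes\n", total);  // ~30 KB, comfortably within SM limits
}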
@@ -1417,13 +2080,26 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
   auto dbias_tensor = convertNVTETensor(dbias);
   auto workspace_tensor = convertNVTETensor(workspace);

-  const QuantizationConfig *quant_config_cpp =
-      reinterpret_cast<const QuantizationConfig *>(quant_config);
+  // Quantization config
+  QuantizationConfig quant_config_cpp;
+  if (quant_config != nullptr) {
+    quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
+  }

+  // Noop flag
+  Tensor dummy_tensor;
+  Tensor *noop_tensor = &dummy_tensor;
+  if (quant_config_cpp.noop_tensor != nullptr) {
+    noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
+  }
-  // extract noop tensor from quant_config_cpp if it's not null
-  const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr;
-  const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor();

+  // Check for unsupported options
+  if (quant_config_cpp.stochastic_rounding) {
+    NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
+               "Stochastic rounding is only supported for NVFP4 quantization.");
+  }

+  // Dispatch to quantization kernel depending on data format
   switch (output_tensor->scaling_mode) {
     case NVTE_DELAYED_TENSOR_SCALING: {
       if (output_tensor->has_columnwise_data()) {
@@ -1435,7 +2111,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
         NVTE_CHECK(output_tensor->has_data(),
                    "Quantizing in only the columnwise direction not supported yet!");
         if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) {
-          cast_transpose(*input_tensor, noop_tensor, output_tensor, stream);
+          cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream);
         } else {
           cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, float, ParamOP, OP>(
               *input_tensor, activation_input_tensor, output_tensor, dbias_tensor,
               workspace_tensor,
@@ -1443,51 +2119,90 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
         }
       } else if (output_tensor->has_data()) {
         fp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(
-            *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor,
+            *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor,
             workspace_tensor, stream);
       }
       break;
     }
     case NVTE_MXFP8_1D_SCALING: {
       mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(
-          *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor,
+          *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor,
           workspace_tensor, stream);
       break;
     }
+    case NVTE_NVFP4_1D_SCALING: {
+      // Check tensors
+      CheckNoopTensor(*noop_tensor, "cast_noop");
+      CheckInputTensor(*input_tensor, "input");
+      CheckOutputTensor(*output_tensor, "output", false);
+
+      // Choose kernel
+      int32_t rows = input_tensor->flat_first_dim();
+      int32_t cols = input_tensor->flat_last_dim();
+      auto dtype = input_tensor->dtype();
+      bool use_optimized_kernel = dtype == DType::kBFloat16 && rows % 32 == 0 &&
+                                  cols % 32 == 0 && output_tensor->has_data();
+
+      // Launch NVFP4 quantize kernel
+      if (use_optimized_kernel) {
+        if (quant_config_cpp.nvfp4_2d_quantization) {
+          nvfp4_quantize_transpose<IS_ACT, ParamOP, OP, true>(
+              *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
+        } else {
+          nvfp4_quantize_transpose<IS_ACT, ParamOP, OP, false>(
+              *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
+        }
+      } else {
+        auto &global_amax = (output_tensor->amax.dptr != nullptr)
+                                ? output_tensor->amax
+                                : output_tensor->columnwise_amax;
+        NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
+                   "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_NVFP4_1D_SCALING for "
+                   "2D quantization");
+        quantize_transpose_vector_blockwise_fp4(
+            /*input=*/input_tensor->data, /*global_amax=*/global_amax,
+            /*scale_inv=*/output_tensor->scale_inv,
+            /*scale_inv_t=*/output_tensor->columnwise_scale_inv,
+            /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
+            /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
+            /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
+            /*swizzled_scale=*/false,
+            /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
+            /*rng_state=*/quant_config_cpp.rng_state,
+            /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
+            /*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
+      }
+      break;
+    }
     case NVTE_BLOCK_SCALING_2D: {
       // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
       NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
                  "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D");
-      bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true;
-      float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
+      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
+      float epsilon = quant_config_cpp.amax_epsilon;
      quantize_transpose_square_blockwise(
          input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
          output_tensor->data, output_tensor->columnwise_data, epsilon,
          /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
-          /*noop_tensor=*/noop_tensor.data, stream);
+          /*noop_tensor=*/noop_tensor->data, stream);
       break;
     }
     case NVTE_BLOCK_SCALING_1D: {
       // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
       NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
                  "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
-      bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false;
-      float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
+      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
+      float epsilon = quant_config_cpp.amax_epsilon;
       FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
       FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
       if (output_tensor->has_data()) {
-        bool rowwise_compact = quant_config_cpp
-                                   ? quant_config_cpp->float8_block_scale_tensor_format ==
-                                         Float8BlockScaleTensorFormat::COMPACT
-                                   : false;
+        bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
+                                Float8BlockScaleTensorFormat::COMPACT);
         rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
                                          : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
       }
       if (output_tensor->has_columnwise_data()) {
-        bool columnwise_compact = quant_config_cpp
-                                      ? quant_config_cpp->float8_block_scale_tensor_format ==
-                                            Float8BlockScaleTensorFormat::COMPACT
-                                      : false;
+        bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
+                                   Float8BlockScaleTensorFormat::COMPACT);
         columnwise_option = columnwise_compact
                                 ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
                                 : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
@@ -1495,7 +2210,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
       quantize_transpose_vector_blockwise(
           input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
           output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
-          columnwise_option, force_pow_2_scales, noop_tensor.data, stream);
+          columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
       break;
     }
     default:
transformer_engine/common/util/dequantize_kernels.cuh
@@ -19,6 +19,8 @@
 #include <transformer_engine/cast.h>

 #include <cfloat>
+#include <cstddef>
+#include <cstdint>
 #include <limits>

 #include "../common.h"
@@ -28,6 +30,7 @@
 #include "math.h"
 #include "ptx.cuh"
 #include "transformer_engine/activation.h"
+#include "transformer_engine/transformer_engine.h"
 #include "transformer_engine/transpose.h"

 namespace transformer_engine {
@@ -339,6 +342,81 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
   NVTE_CHECK_CUDA(cudaGetLastError());
 #endif
 }

#if CUDA_VERSION >= 12080
template <typename OType>
__global__ void __launch_bounds__(512)
    dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales,
                          const float *const tensor_amax, const size_t N, const size_t M,
                          const size_t scale_stride) {
  const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t x = thread_idx % M;
  const size_t y = thread_idx / M;

  union fp4vec {
    uint64_t vec;
    fp4e2m1x4 small_vec[4];
  };
  using OVec = Vec<OType, 4>;

  const uint64_t *const input_vectorized = reinterpret_cast<const uint64_t *>(input);
  OVec *output_vec = reinterpret_cast<OVec *>(output);

  const size_t my_index = x + y * M;
  const size_t my_scale_index = x + y * scale_stride;
  const size_t my_output_index = (x + y * M) * 4;
  fp4vec value;
  value.vec = input_vectorized[my_index];
  fp8e4m3 scale = scales[my_scale_index];
  float amax = *tensor_amax;
  constexpr float factor_inv = 1.0 / (6.0 * 448.0);
  float final_scale = static_cast<float>(scale) * amax * factor_inv;
#pragma unroll
  for (int i = 0; i < 4; i++) {
    float4 current = static_cast<float4>(value.small_vec[i]);
    OVec out;
    out.data.elt[0] = static_cast<OType>(current.x * final_scale);
    out.data.elt[1] = static_cast<OType>(current.y * final_scale);
    out.data.elt[2] = static_cast<OType>(current.z * final_scale);
    out.data.elt[3] = static_cast<OType>(current.w * final_scale);
    output_vec[my_output_index + i] = out;
  }
}
#endif  // CUDA_VERSION

void fp4_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
#if CUDA_VERSION >= 12080
  CheckInputTensor(input, "input");
  CheckOutputTensor(*output, "output");
  NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type.");
  NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision.");
  NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");

  constexpr int FP4_BLOCK_SIZE = 16;
  const size_t N = input.flat_first_dim();
  const size_t M = input.flat_last_dim();

  NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ",
             FP4_BLOCK_SIZE, ", but got ", input.data.shape, ".");

  const size_t Mread = M / FP4_BLOCK_SIZE;
  const size_t total = N * Mread;
  const size_t threads = 512;
  const size_t blocks = DIVUP(total, threads);

  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      output->data.dtype, OType,
      dequantize_fp4_kernel<<<blocks, threads, 0, stream>>>(
          input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
          reinterpret_cast<fp8e4m3 *>(input.scale_inv.dptr),
          reinterpret_cast<float *>(input.amax.dptr), N, Mread,
          input.scale_inv.shape.back()););  // NOLINT(*)
  NVTE_CHECK_CUDA(cudaGetLastError());
#else
  NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!");
#endif  // CUDA_VERSION >= 12080
}
}  // namespace dequantization
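For a single element, the dequantization math above reduces to a nibble lookup followed by one multiply. A standalone C++ sketch (ours, not part of the commit; the scale and amax values are made up):

#include <cstdint>
#include <cstdio>

int main() {
  // The 8 non-negative E2M1 magnitudes; the sign lives in the nibble's top bit.
  const float e2m1[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
  const uint8_t nibble = 0b1101;  // sign = 1, magnitude index 5 -> -3.0
  const float raw = (nibble & 0x8 ? -1.0f : 1.0f) * e2m1[nibble & 0x7];
  const float scale_e4m3 = 224.0f;  // per-block decode scale as stored (after FP8 rounding)
  const float amax = 12.5f;         // tensor-level amax
  // Same composition as the kernel: final_scale = scale * amax / (6 * 448)
  const float final_scale = scale_e4m3 * amax / (6.0f * 448.0f);
  std::printf("dequantized = %f\n", raw * final_scale);
}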
namespace detail {
@@ -347,16 +425,24 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream)
   CheckInputTensor(input, "cast_input");
   CheckOutputTensor(*output, "cast_output");

-  if (is_tensor_scaling(input.scaling_mode)) {
+  switch (input.scaling_mode) {
+    case NVTE_DELAYED_TENSOR_SCALING: {
       dequantization::fp8_dequantize(input, output, stream);
-  } else if (is_mxfp_scaling(input.scaling_mode)) {
+      break;
+    }
+    case NVTE_MXFP8_1D_SCALING: {
       if (is_supported_by_CC_100()) {
         dequantization::mxfp8_dequantize(input, output, stream);
       } else {
         NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
       }
-  } else {
-    // TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING
+      break;
+    }
+    case NVTE_NVFP4_1D_SCALING: {
+      dequantization::fp4_dequantize(input, output, stream);
+      break;
+    }
+    default:
+      NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
   }
 }
transformer_engine/common/util/logging.h
@@ -23,6 +23,8 @@
 #endif  // __HIP_PLATFORM_AMD__
 #include <nvrtc.h>
+#include "nccl.h"
+
 #ifdef NVTE_WITH_CUBLASMP
 #include <cublasmp.h>
 #endif  // NVTE_WITH_CUBLASMP
@@ -147,4 +149,12 @@
 #endif  // NVTE_WITH_CUBLASMP

+#define NVTE_CHECK_NCCL(expr)                                                     \
+  do {                                                                            \
+    const ncclResult_t status_NVTE_CHECK_NCCL = (expr);                           \
+    if (status_NVTE_CHECK_NCCL != ncclSuccess) {                                  \
+      NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL));     \
+    }                                                                             \
+  } while (false)
+
 #endif  // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
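The do { ... } while (false) wrapper is the standard trick that makes a multi-statement macro behave like a single statement. A self-contained C++ illustration of the same pattern with a stand-in status enum (ours; NCCL itself is not needed to see the point):

#include <cstdio>
#include <cstdlib>

enum Status { OK = 0, FAIL = 1 };  // stand-in for ncclResult_t / ncclSuccess
const char *status_str(Status s) { return s == OK ? "ok" : "failure"; }

#define CHECK_STATUS(expr)                                        \
  do {                                                            \
    const Status status_ = (expr);                                \
    if (status_ != OK) {                                          \
      std::fprintf(stderr, "Error: %s\n", status_str(status_));   \
      std::exit(1);                                               \
    }                                                             \
  } while (false)

int main() {
  if (true)
    CHECK_STATUS(OK);  // expands to exactly one statement, so this if/else stays well-formed
  else
    std::puts("unreachable");
}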
transformer_engine/common/util/math.h
@@ -11,6 +11,11 @@ namespace transformer_engine {
 struct Empty {};

+struct ClampedSwiGLUParam {
+  float limit;
+  float alpha = 1.702f;  // Default value for QuickGELU
+};
+
 template <typename OType, typename IType>
 __device__ inline OType gelu(const IType val, const Empty &) {
   const float cval = val;
@@ -38,17 +43,29 @@ __device__ inline OType dsigmoid(const IType val, const Empty& e) {
   return s * (1.f - s);
 }

+template <typename OType, typename IType>
+__device__ inline OType qgelu_with_alpha(const IType val, const float alpha) {
+  const float cval = val;
+  Empty e = {};
+  return cval * sigmoid<float, float>(alpha * cval, e);
+}
+
 template <typename OType, typename IType>
 __device__ inline OType qgelu(const IType val, const Empty &e) {
-  const float cval = val;
-  return cval * sigmoid<float, float>(1.702f * cval, e);
+  return qgelu_with_alpha<OType, IType>(val, 1.702f);
 }

+template <typename OType, typename IType>
+__device__ inline OType dqgelu_with_alpha(const IType val, const float alpha) {
+  const float cval = val;
+  Empty e = {};
+  return alpha * cval * dsigmoid<float, float>(alpha * cval, e) +
+         sigmoid<float, float>(alpha * cval, e);
+}
+
 template <typename OType, typename IType>
 __device__ inline OType dqgelu(const IType val, const Empty &e) {
-  const float cval = val;
-  return 1.702f * cval * dsigmoid<float, float>(1.702f * cval, e) +
-         sigmoid<float, float>(1.702f * cval, e);
+  return dqgelu_with_alpha<OType, IType>(val, 1.702f);
 }

 template <typename OType, typename IType>
@@ -57,12 +74,26 @@ __device__ inline OType silu(const IType val, const Empty& e) {
   return cval * sigmoid<float, float>(cval, e);
 }

+template <typename OType, typename IType>
+__device__ inline OType clamped_silu(const IType val, const ClampedSwiGLUParam &p) {
+  const float cval = min(p.limit, static_cast<float>(val));  // Clamping
+  return qgelu_with_alpha<OType, float>(cval, p.alpha);
+}
+
 template <typename OType, typename IType>
 __device__ inline OType dsilu(const IType val, const Empty &e) {
   const float cval = val;
   return cval * dsigmoid<float, float>(cval, e) + sigmoid<float, float>(cval, e);
 }

+template <typename OType, typename IType>
+__device__ inline OType clamped_dsilu(const IType val, const ClampedSwiGLUParam &p) {
+  const bool dclamp_val = static_cast<float>(val) <= p.limit;
+  const float clamp_val = min(static_cast<float>(val), p.limit);
+  const float dsilu_val = dqgelu_with_alpha<OType, float>(clamp_val, p.alpha);
+  return dclamp_val ? dsilu_val : 0.0f;
+}
+
 template <typename OType, typename IType>
 __device__ inline OType relu(IType value, const Empty &) {
   return fmaxf(value, 0.f);
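The forward/backward pairing of the new clamped activations can be verified on the host with a finite difference. A plain C++ sketch (ours, mirroring the device templates above in float):

#include <algorithm>
#include <cmath>
#include <cstdio>

float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }
float dsigmoid(float x) { float s = sigmoid(x); return s * (1.0f - s); }

float qgelu_with_alpha(float x, float a) { return x * sigmoid(a * x); }
float dqgelu_with_alpha(float x, float a) {
  return a * x * dsigmoid(a * x) + sigmoid(a * x);
}

float clamped_silu(float x, float limit, float a) {
  return qgelu_with_alpha(std::min(x, limit), a);
}
float clamped_dsilu(float x, float limit, float a) {
  // Past the clamp the forward value is constant, so the derivative is zero.
  return x <= limit ? dqgelu_with_alpha(std::min(x, limit), a) : 0.0f;
}

int main() {
  const float limit = 2.0f, alpha = 1.702f, x = 1.3f, h = 1e-3f;
  const float fd =
      (clamped_silu(x + h, limit, alpha) - clamped_silu(x - h, limit, alpha)) / (2 * h);
  std::printf("analytic=%f  finite-diff=%f\n", clamped_dsilu(x, limit, alpha), fd);
}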
transformer_engine/common/util/nvfp4_transpose.cuh
0 → 100644
View file @
063ef88d
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file nvfp4_transpose.cuh
* \brief CUDA kernels to cast to NVFP4 and transpose.
*/
#ifndef TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
#define TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#if CUDA_VERSION > 12080
#include <cuda_fp4.h>
#endif // CUDA_VERSION > 12080
#include <cfloat>
#include "../common.h"
#include "../utils.cuh"
#include "curanddx.hpp"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/transformer_engine.h"
namespace
transformer_engine
{
#if CUDA_VERSION > 12080
namespace
nvfp4_transpose
{
using
RNG
=
decltype
(
curanddx
::
Generator
<
curanddx
::
philox4_32
>
()
+
curanddx
::
PhiloxRounds
<
10
>
()
+
curanddx
::
SM
<
800
>
()
+
curanddx
::
Thread
());
using
namespace
ptx
;
using
nvfp4_scale_t
=
fp8e4m3
;
constexpr
size_t
SCALE_DIM
=
16
;
// NVFP4 block (x16 elts)
constexpr
size_t
CHUNK_DIM_Y
=
128
;
constexpr
size_t
CHUNK_DIM_X
=
128
;
constexpr
size_t
THREADS_NUM
=
128
;
constexpr
size_t
SCALES_PER_CHUNK_Y
=
CHUNK_DIM_Y
/
SCALE_DIM
;
constexpr
size_t
SCALES_PER_CHUNK_X
=
CHUNK_DIM_X
/
SCALE_DIM
;
constexpr
size_t
SCALES_PER_THREAD
=
2
*
(
CHUNK_DIM_Y
*
CHUNK_DIM_X
)
/
SCALE_DIM
/
THREADS_NUM
;
constexpr
size_t
RNG_GENS_PER_THREAD
=
SCALES_PER_THREAD
/
4
;
// Each call generates 4x uint32_t random numbers
constexpr
size_t
TILE_DIM_Y
=
32
;
constexpr
size_t
TILE_DIM_X
=
128
;
// SHould this be SCALE_DIM or BLOCK_DIM? Both are 16, should work for both 1D and 2D
constexpr
size_t
SCALES_PER_TILE_Y
=
TILE_DIM_Y
/
SCALE_DIM
;
constexpr
size_t
SCALES_PER_TILE_X
=
TILE_DIM_X
/
SCALE_DIM
;
// 128 / 16 = 8
constexpr
size_t
TILES_Y
=
CHUNK_DIM_Y
/
TILE_DIM_Y
;
constexpr
size_t
TILES_X
=
CHUNK_DIM_X
/
TILE_DIM_X
;
constexpr
size_t
STAGES
=
TILES_Y
*
TILES_X
;
constexpr
size_t
BUFFS_NUM
=
2
;
constexpr
size_t
BUFF_DIM_Y
=
TILE_DIM_Y
;
constexpr
size_t
BUFF_DIM_X
=
TILE_DIM_X
;
constexpr
size_t
BUFF_SIZE
=
BUFF_DIM_Y
*
BUFF_DIM_X
;
constexpr
size_t
BUFF_SIZE_TOTAL
=
BUFF_SIZE
*
BUFFS_NUM
;
// Input buffer (BF16)
constexpr
size_t
BUFF_IN_DIM_Y
=
BUFF_DIM_Y
;
constexpr
size_t
BUFF_IN_DIM_X
=
BUFF_DIM_X
;
constexpr
size_t
BUFF_IN_SIZE
=
BUFF_IN_DIM_Y
*
BUFF_IN_DIM_X
;
// Output buffer (NVFP4)
constexpr
size_t
BUFF_OUT_DIM_Y
=
BUFF_DIM_Y
;
constexpr
size_t
BUFF_OUT_DIM_X
=
(
BUFF_DIM_X
*
4
)
/
8
;
constexpr
size_t
BUFF_OUT_SIZE
=
BUFF_OUT_DIM_Y
*
BUFF_OUT_DIM_X
;
// Output transpose buffer (NVFP4)
constexpr
size_t
BUFF_OUT_T_DIM_Y
=
BUFF_DIM_X
;
constexpr
size_t
BUFF_OUT_T_DIM_X
=
(
BUFF_DIM_Y
*
4
)
/
8
;
constexpr
size_t
BUFF_OUT_T_SIZE
=
BUFF_OUT_T_DIM_Y
*
BUFF_OUT_T_DIM_X
;
// Manual swizzling parameters to reduce SHMEM bank conflicts
constexpr
size_t
PACK_SIZE
=
8
;
constexpr
size_t
WAVES
=
SCALE_DIM
/
PACK_SIZE
;
constexpr
size_t
SCALING_FACTORS_PER_TILE_X
=
TILE_DIM_X
/
SCALE_DIM
;
constexpr
size_t
THREADS_X_ROWWISE
=
SCALING_FACTORS_PER_TILE_X
;
// 128 / 16 = 8
constexpr
size_t
THREADS_Y_ROWWISE
=
THREADS_NUM
/
THREADS_X_ROWWISE
;
// 128 / 8 = 16
constexpr
size_t
ITERATIONS_NORMAL
=
BUFF_DIM_Y
/
THREADS_Y_ROWWISE
;
// 32/ 16 = 2
constexpr
size_t
ITERATIONS_TRANSPOSE
=
BUFF_IN_DIM_Y
/
SCALE_DIM
;
constexpr
size_t
BUFF_OUT_IT_OFFSET
=
BUFF_OUT_T_DIM_X
/
ITERATIONS_TRANSPOSE
;
static_assert
(
BUFF_DIM_Y
>=
SCALE_DIM
&&
"Number of buffer rows must be greater or equal to the size of the columwise "
"scaling block
\0
"
);
static_assert
(
CHUNK_DIM_Y
>=
BUFF_DIM_Y
);
static_assert
(
BUFF_DIM_Y
>=
THREADS_Y_ROWWISE
&&
"Number of buffer rows must be greater or equal to the number of rowwise "
"processing threads in Y dimension
\0
"
);
// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr
size_t
TOTAL_BANKS_WIDTH
=
(
32
*
4
*
8
)
/
4
;
// 256
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr
size_t
THREADS_PER_BANK
=
TOTAL_BANKS_WIDTH
/
SCALE_DIM
;
// 8 = 128 / 16
// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax,
                                                                         const float S_enc) {
  // constexpr float rcp_6f = 1.0f / 6.0f;
  // const float S_dec_b = block_amax * rcp_6f;
  // const nvfp4_scale_t S_dec_b_fp8 = static_cast<nvfp4_scale_t>(S_dec_b * S_enc);
  // return S_dec_b_fp8;
  // NOTE: Divide by 6.0f is not elegant and not efficient.
  // However, this is part of the emulation code to ensure exact match.
  using namespace detail;
  constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;  // 6.0f;
  const float S_dec_b = block_amax / fp4_max * S_enc;
  return static_cast<nvfp4_scale_t>(fminf(S_dec_b, TypeExtrema<float>::max));
}

// Compute the global encode scale factor for a given global amax
__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) {
  using namespace detail;
  constexpr float fp8_max = TypeExtrema<fp8e4m3>::max;  // 448.0f;
  constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;  // 6.0f;
  float global_encode_scale = fp8_max * fp4_max / global_amax;
  // If scale is infinity, return max value of float32
  global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
  // If global amax is 0 or infinity, return 1
  if (global_amax == 0.0f || global_encode_scale == 0.0f) {
    return 1.0f;
  }
  return global_encode_scale;
}
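
// Worked example of the two-level NVFP4 scaling above (illustrative numbers only):
//   global_amax = 672  =>  S_enc   = 448 * 6 / 672   = 4.0  (global encode scale)
//   block_amax  = 3    =>  S_dec_b = (3 / 6) * 4.0   = 2.0  (stored as the E4M3 block scale)
// The kernels below then use the per-element encode multiplier
//   block_scale_inverse = 1 / (S_dec_b * S_dec) = S_enc / S_dec_b = 4.0 / 2.0 = 2.0,
// so the block maximum 3 lands exactly on the FP4 maximum: 3 * 2.0 = 6.0 (E2M1 max).
// A consumer recovers values as fp4_val * S_dec_b / S_enc = fp4_val * 0.5.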
__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) {
  if (rnd_idx == 4) {
    rnd_idx = 0;
    curanddx::uniform_bits dist;
    random_uint4 = dist.generate4(rng);
  }
  // Treat uint4 as an array of 4x uint32_t elements for indexing
  const uint32_t *const rbits_arr = reinterpret_cast<uint32_t *>(&random_uint4);
  const uint32_t rbits = rbits_arr[rnd_idx++];
  return rbits;
}
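
// Note on the RNG batching above: curanddx's Philox generator yields four 32-bit words
// per generate4() call, so get_rbits() hands them out one at a time and only re-generates
// once all four are consumed. Each 4x convert below uses one 32-bit word for 4 output
// elements, i.e. a single generate4() call covers 16 quantized elements per thread.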
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(
    const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
  uint16_t out_4x = 0;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  asm volatile(
      "{\n"
      ".reg.b64 v01; \n\t"
      ".reg.b64 v23; \n\t"
      ".reg.b16 v0_bf16; \n\t"
      ".reg.b16 v1_bf16; \n\t"
      ".reg.b16 v2_bf16; \n\t"
      ".reg.b16 v3_bf16; \n\t"
      ".reg.b32 v0; \n\t"
      ".reg.b32 v1; \n\t"
      ".reg.b32 v2; \n\t"
      ".reg.b32 v3; \n\t"
      "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
      "cvt.f32.bf16 v0, v0_bf16; \n\t"
      "cvt.f32.bf16 v1, v1_bf16; \n\t"
      "cvt.f32.bf16 v2, v2_bf16; \n\t"
      "cvt.f32.bf16 v3, v3_bf16; \n\t"
      "mov.b64 v01, {v0, v1}; \n\t"
      "mov.b64 v23, {v2, v3}; \n\t"
      "mul.f32x2 v01, v01, %2; \n\t"  // mind the shuffled elements order
      "mul.f32x2 v23, v23, %2; \n\t"  // mind the shuffled elements order
      "mov.b64 {v1, v0}, v01; \n\t"
      "mov.b64 {v3, v2}, v23; \n\t"
      "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t"  // mind the shuffled elements order
      "}"
      : "=h"(out_4x)
      : "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
#else
  NVTE_DEVICE_ERROR(
      "FP4 cvt PTX instructions are architecture-specific. "
      "Try recompiling with sm_XXXa instead of sm_XXX.");
#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
}

__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x,
                                                                    const float2 scale,
                                                                    const uint32_t rbits) {
  // NOTE: rbits unused for rn.
  uint32_t out_4x = 0;  // Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  asm volatile(
      "{\n"
      ".reg.b64 v01; \n\t"
      ".reg.b64 v23; \n\t"
      ".reg.b16 v0_bf16; \n\t"
      ".reg.b16 v1_bf16; \n\t"
      ".reg.b16 v2_bf16; \n\t"
      ".reg.b16 v3_bf16; \n\t"
      ".reg.b32 v0; \n\t"
      ".reg.b32 v1; \n\t"
      ".reg.b32 v2; \n\t"
      ".reg.b32 v3; \n\t"
      ".reg.b8 f0; \n\t"
      ".reg.b8 f1; \n\t"
      "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
      "cvt.f32.bf16 v0, v0_bf16; \n\t"
      "cvt.f32.bf16 v1, v1_bf16; \n\t"
      "cvt.f32.bf16 v2, v2_bf16; \n\t"
      "cvt.f32.bf16 v3, v3_bf16; \n\t"
      "mov.b64 v01, {v0, v1}; \n\t"
      "mov.b64 v23, {v2, v3}; \n\t"
      "mul.f32x2 v01, v01, %2; \n\t"  // mind the shuffled elements order
      "mul.f32x2 v23, v23, %2; \n\t"  // mind the shuffled elements order
      "mov.b64 {v1, v0}, v01; \n\t"
      "mov.b64 {v3, v2}, v23; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3; \n\t"
      "mov.b32 %0, {f0, f1, f0, f1}; \n\t"
      "}"
      : "=r"(out_4x)
      : "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)));
#else
  NVTE_DEVICE_ERROR(
      "FP4 cvt PTX instructions are architecture-specific. "
      "Try recompiling with sm_XXXa instead of sm_XXX.");
#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
}

template <bool USE_STOCHASTIC_ROUNDING>
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x,
                                                            const float2 scale,
                                                            const uint32_t rbits) {
  if constexpr (USE_STOCHASTIC_ROUNDING) {
    return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits);
  } else {
    return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits);
  }
}

__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(
    const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) {
  uint16_t out_4x = 0;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  asm volatile(
      "{\n"
      ".reg.b64 v01; \n\t"
      ".reg.b64 v23; \n\t"
      ".reg.b32 v0; \n\t"
      ".reg.b32 v1; \n\t"
      ".reg.b32 v2; \n\t"
      ".reg.b32 v3; \n\t"
      "mov.b64 {v0, v1} , %1; \n\t"
      "mov.b64 {v2, v3} , %2; \n\t"
      "mov.b64 v01, {v0, v1}; \n\t"
      "mov.b64 v23, {v2, v3}; \n\t"
      "mul.f32x2 v01, v01, %3; \n\t"  // mind the shuffled elements order
      "mul.f32x2 v23, v23, %3; \n\t"  // mind the shuffled elements order
      "mov.b64 {v1, v0}, v01; \n\t"
      "mov.b64 {v3, v2}, v23; \n\t"
      "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t"  // mind the shuffled elements order
      "}"
      : "=h"(out_4x)
      : "l"(reinterpret_cast<const uint64_t &>(in01)),
        "l"(reinterpret_cast<const uint64_t &>(in23)),
        "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
#else
  NVTE_DEVICE_ERROR(
      "FP4 cvt PTX instructions are architecture-specific. "
      "Try recompiling with sm_XXXa instead of sm_XXX.");
#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
}

__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
                                                                    const float2 in23,
                                                                    const float2 scale,
                                                                    const uint32_t rbits) {
  // NOTE: rbits unused for rn.
  uint32_t out_4x = 0;  // Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  asm volatile(
      "{\n"
      ".reg.b64 v01; \n\t"
      ".reg.b64 v23; \n\t"
      ".reg.b32 v0; \n\t"
      ".reg.b32 v1; \n\t"
      ".reg.b32 v2; \n\t"
      ".reg.b32 v3; \n\t"
      ".reg.b8 f0; \n\t"
      ".reg.b8 f1; \n\t"
      "mov.b64 {v0, v1} , %1; \n\t"
      "mov.b64 {v2, v3} , %2; \n\t"
      "mov.b64 v01, {v0, v1}; \n\t"
      "mov.b64 v23, {v2, v3}; \n\t"
      "mul.f32x2 v01, v01, %3; \n\t"  // mind the shuffled elements order
      "mul.f32x2 v23, v23, %3; \n\t"  // mind the shuffled elements order
      "mov.b64 {v1, v0}, v01; \n\t"
      "mov.b64 {v3, v2}, v23; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3; \n\t"
      "mov.b32 %0, {f0, f1, f0, f1}; \n\t"
      "}"
      : "=r"(out_4x)
      : "l"(reinterpret_cast<const uint64_t &>(in01)),
        "l"(reinterpret_cast<const uint64_t &>(in23)),
        "l"(reinterpret_cast<const uint64_t &>(scale)));
#else
  NVTE_DEVICE_ERROR(
      "FP4 cvt PTX instructions are architecture-specific. "
      "Try recompiling with sm_XXXa instead of sm_XXX.");
#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
}

template <bool USE_STOCHASTIC_ROUNDING>
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23,
                                                            const float2 scale,
                                                            const uint32_t rbits) {
  if constexpr (USE_STOCHASTIC_ROUNDING) {
    return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits);
  } else {
    return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits);
  }
}
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
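
// Reference semantics of the 4x converters above, as a hedged scalar sketch (the PTX
// paths fuse the multiply into packed f32x2 math and do the packing/rounding in hardware):
//
//   // for each of the four inputs x_i (bf16 or fp32):
//   //   y_i   = x_i * block_scale_inverse;     // mul.f32x2 on packed pairs
//   //   out_i = cast_to_e2m1_satfinite(y_i);   // .rs consumes rbits for stochastic
//   //                                          // rounding; .rn rounds to nearest even
//   // the four 4-bit results are packed into one 16-bit fp4e2m1x4 value.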
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
          typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM)
    nvfp4_transpose_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
                           const __grid_constant__ CUtensorMap tensor_map_output,
                           const __grid_constant__ CUtensorMap tensor_map_output_t,
                           nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr,
                           const float *noop, const float *const amax_rowwise_ptr,
                           const float *const amax_colwise_ptr, const size_t rows,
                           const size_t cols, const size_t scale_stride,
                           const size_t scale_stride_t, const size_t *rng_state) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
      (!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);

  using IType2 = typename ptx::FPx2<IType>;

  if constexpr (!COMPUTE_ACTIVATIONS) {
    if (noop != nullptr && noop[0] == 1.0f) {
      return;
    }
  }

  const size_t rng_sequence =
      threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
  const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
  const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
  RNG rng(rng_seed, rng_sequence, rng_offset);
  curanddx::uniform_bits dist;
  uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0};
  // Index of the random number. It increments each time it is used and resets to 0 when it reaches 4
  int rnd_idx = 0;

  constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS;

  const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
  const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
  const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
  const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y;
  const size_t chunk_rows = rows - block_offset_Y;

  const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
  const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X;
  const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
  const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y;

  const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
  const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
  const size_t tid_X_colwise = threadIdx.x;
  const size_t tid_Y_t = tid_X_colwise;
  // const size_t tid_X_t = 0;

  const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
  const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM;
  const size_t thread_offset_X_colwise = tid_X_colwise;

  const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
  const size_t row_base_colwise = block_offset_Y;
  const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;

  const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);

  const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
  const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
  const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t;
  const size_t scales_offset_X_t = scales_block_offset_X_t;
  const size_t SFs_per_row = cols / SCALE_DIM;

  const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row;
  const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols;

  // Helps resolve bank conflicts in shmem
  const int thread_lane = threadIdx.x % THREADS_PER_WARP;
  const int bank_group = thread_lane / THREADS_PER_BANK;

  constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
  constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
  constexpr size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out =
      DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);

  constexpr size_t in_mem = buff_size_aligned_in;

  constexpr size_t out_mem_rowwise_data = buff_size_aligned_out;
  constexpr size_t out_mem_colwise_data = buff_size_aligned_out;
  constexpr size_t out_mem_rowwise_scales = 0;

  extern __shared__ char dynamic_shmem[];
  uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) Does not guarantee the pointer to be aligned!
  uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
                     ~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));

  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
  IType *in_sh = reinterpret_cast<IType *>(dshmem);
  fp4e2m1x2 *out_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
  fp4e2m1x2 *out_t_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem + out_mem_rowwise_data);
  nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
      dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
  nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
      dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
  IType *cached_act_sh = in_sh;  // in_sh is used as a cache buffer

  constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;

  const bool is_master_thread = (threadIdx.x == 0);

  // Compute the global encoding/decoding scaling factors for all S_dec_b
  const float S_enc_rowwise = (amax_rowwise_ptr == nullptr)
                                  ? 1.0f
                                  : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
  // NOTE: This is to match with how the emulation code was written.
  const float S_dec_rowwise = 1.0 / S_enc_rowwise;

  const float S_enc_colwise = (amax_colwise_ptr == nullptr)
                                  ? S_enc_rowwise
                                  : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
  const float S_dec_colwise = 1.0 / S_enc_colwise;

  float thread_amax = 0.0f;

  // Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ alignas(8) uint64_t mbar[STAGES];

  initialize_barriers<STAGES, THREADS_NUM>(mbar, is_master_thread);

  copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
                    &mbar[0], is_master_thread);

#pragma unroll
  for (size_t stage = 0; stage < STAGES; ++stage) {
    const size_t buff = stage % BUFFS_NUM;
    const size_t next_stage = stage + 1;
    const size_t stage_offset_Y = stage * BUFF_DIM_Y;

    const size_t buff_offset_in = buff * BUFF_IN_SIZE;
    const size_t buff_offset_out = buff * BUFF_OUT_SIZE;
    const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE;

    if (next_stage < STAGES) {
      // Wait for TMA transfer to have finished reading shared memory.
      // I.e. the buffer is ready to be written to
      ptx::cp_async_bulk_wait_group_read<1>();

      const size_t next_buff = next_stage % BUFFS_NUM;
      const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
      const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
      const size_t global_offset_X = block_offset_X;
      const size_t next_buff_offset = next_buff * BUFF_IN_SIZE;

      copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
                        global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
    }

    ptx::fence_proxy_async_shared_cta();

    // Wait for the data to have arrived
    ptx::mbarrier_wait_parity(&mbar[stage], 0);

    float block_amax = 0.0f;

    // COLWISE scaling
    if constexpr (RETURN_TRANSPOSE) {
#pragma unroll
      for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) {
        const size_t in_thread_offset_Y = 0 + it * SCALE_DIM;
        const size_t in_thread_offset_X = thread_offset_X_colwise;

        const size_t out_t_thread_offset_Y = thread_offset_X_colwise;
        const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET;

        const size_t shmem_offset_base_colwise_in =
            buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X;
        const size_t shmem_offset_base_colwise_out_t =
            buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X;

        block_amax = 0.0f;
        float in_compute_colwise[SCALE_DIM];
        IType in_colwise_IType[SCALE_DIM];

        // 1. Read/Compute elements. Find NVFP4-block AMAX
        if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
          IType block_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
          for (int i = 0; i < SCALE_DIM; ++i) {
            const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
            in_colwise_IType[i] = in_sh[shmem_offset_colwise];
            block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i]));
          }
          block_amax = static_cast<float>(block_amax_f16);
        } else {
#pragma unroll
          for (int i = 0; i < SCALE_DIM; ++i) {
            const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
            float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
            if constexpr (COMPUTE_ACTIVATIONS) {
              elt = OP(elt, {});
            }
            // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
            if constexpr (!std::is_same_v<IType, float>) {
              elt = static_cast<float>(static_cast<IType>(elt));
            }
            // Cache computed activations to avoid computing them again in the 2nd pass along another dimension
            if constexpr (IS_CACHED_ACT_OP) {
              cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
            }
            if constexpr (COMPUTE_ACTIVATIONS) {
              const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows);
              const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
              if (!out_of_bounds) {
                block_amax = fmaxf(block_amax, fabsf(elt));
              }
            } else {
              // If no activation, elt is 0 so we can safely do this
              block_amax = fmaxf(block_amax, fabsf(elt));
            }
            in_compute_colwise[i] = elt;
          }
        }

        // 2. Compute E4M3 scaling factor
        const nvfp4_scale_t S_dec_b_fp8 =
            compute_decoding_scaling_factor(block_amax, S_enc_colwise);

        // Store scaling factors through SHMEM
        const size_t scale_idx_sh =
            tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it;
        out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8;

        // Compute "correct" per-block encoding scaling factor
        constexpr float float_max = detail::TypeExtrema<float>::max;
        const float block_scale_inverse = fminf(
            1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_colwise), float_max);  // S_enc_b_fp8
        const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};

        // 3. Scale elements
        fp4e2m1x4 regs[SCALE_DIM / 4];
#pragma unroll
        for (int e = 0; e < SCALE_DIM / 4; ++e) {
          const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
          if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
            const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_colwise_IType[4 * e]);
            regs[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(elts, block_scale_inverse_2x,
                                                                      rbits);
          } else {
            const float2 in01 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e]);
            const float2 in23 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e + 2]);
            regs[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
                in01, in23, block_scale_inverse_2x, rbits);
          }
        }

        const int group = thread_lane / 16;
        uint32_t val[2];
        uint32_t *regs_4x = reinterpret_cast<uint32_t *>(regs);
        // Helps reduce bank conflicts
        switch (group) {
          case 0:
            val[0] = regs_4x[0];
            val[1] = regs_4x[1];
            break;
          case 1:
            val[0] = regs_4x[1];
            val[1] = regs_4x[0];
            break;
        }
        uint32_t *out_t_data_sh_as_uint32_t =
            reinterpret_cast<uint32_t *>(&out_t_data_sh[shmem_offset_base_colwise_out_t]);
        out_t_data_sh_as_uint32_t[group] = val[0];            // idx1 = (group + 0) % 2;
        out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1];  // idx2 = (group + 1) % 2;
      }
    }

    // ROWWISE scaling
    {
      const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
      for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) {
        const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;

        const size_t shmem_offset_base_rowwise_in =
            buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
        const size_t shmem_offset_base_rowwise_out =
            buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;

        const size_t it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE;

        block_amax = 0.0f;
        float in_compute_rowwise[SCALE_DIM];
        Vec<IType, PACK_SIZE> in_cached[WAVES];

        // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
        Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];

        // 1. Read/Compute elements. Find NVFP4-block AMAX
        if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
          IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
            const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
            // Load elements
            in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
            for (int e = 0; e < PACK_SIZE / 2; ++e) {
              ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
            }
          }
          block_amax =
              static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
        } else if constexpr (IS_CACHED_ACT_OP) {
          // ensures that all writes to cache made in the section above are visible to all threads
          __syncthreads();
          IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
            const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;

            const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
            const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
            const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);

            // Load cached elements
            in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
            // Since the TMA data-alignment requirement is 16B (i.e. cols % 8 == 0 for BF16
            // elements), a single check (w.r.t. the column direction) is sufficient to ensure
            // the entire wave is inside the boundaries
            if (!out_of_bounds) {
              if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
                for (int e = 0; e < PACK_SIZE; ++e) {
                  block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e]));
                }
              } else {
#pragma unroll
                for (int e = 0; e < PACK_SIZE; e += 2) {
                  const IType2 in_cached_2x = {in_cached[w].data.elt[e],
                                               in_cached[w].data.elt[e + 1]};
                  ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
                }
              }
            }
          }
          if constexpr (!std::is_same_v<IType, float>) {
            block_amax =
                static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
          }
        } else {
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
            const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
            Vec<IType, PACK_SIZE> in;
            Vec<IType, PACK_SIZE> act_in;
            in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
            for (int e = 0; e < PACK_SIZE; ++e) {
              const size_t j = w * PACK_SIZE + e;
              // Compute element
              float elt = static_cast<float>(in.data.elt[e]);
              if constexpr (COMPUTE_ACTIVATIONS) {
                elt = OP(elt, {});
              }
              // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
              if constexpr (!std::is_same_v<IType, float>) {
                elt = static_cast<float>(static_cast<IType>(elt));
              }
              if constexpr (COMPUTE_ACTIVATIONS) {
                const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
                const bool swizzled_col_out_of_bounds =
                    (block_offset_X + swizzled_thread_idx >= cols);
                const bool out_of_bounds =
                    (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
                if (!out_of_bounds) {
                  block_amax = fmaxf(block_amax, fabsf(elt));
                }
              } else {
                // If no activation, elt is 0 so we can safely do this
                block_amax = fmaxf(block_amax, fabsf(elt));
              }
              in_compute_rowwise[j] = elt;
            }
          }
        }

        // 2. Compute E4M3 scaling factor
        const nvfp4_scale_t S_dec_b_fp8 =
            compute_decoding_scaling_factor(block_amax, S_enc_rowwise);

        // Check boundaries
        const size_t scales_offset_Y =
            scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE;
        const size_t scales_offset_X = scales_offset_X_rowwise;
        const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X;

        // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows;
        const bool rowwise_scale_is_within_bounds_Y =
            (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows;
        if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) {
          scales_ptr[scale_idx_global] = S_dec_b_fp8;
        }

        // Compute "correct" per-block encoding scaling factor
        constexpr float float_max = detail::TypeExtrema<float>::max;
        const float block_scale_inverse = fminf(
            1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_rowwise), float_max);  // S_enc_b_fp8
        const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};

        // 3. Scale elements
#pragma unroll
        for (int w = 0; w < WAVES; ++w) {
          Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
          for (int e = 0; e < PACK_SIZE / 4; ++e) {
            const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
            IType2 in01;
            IType2 in23;
            if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
              const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_IType[w].data.elt[2 * e]);
              out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
                  elts, block_scale_inverse_2x, rbits);
            } else if constexpr (IS_CACHED_ACT_OP) {
              const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_cached[w].data.elt[4 * e]);
              out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
                  elts, block_scale_inverse_2x, rbits);
            } else {
              const int j = w * PACK_SIZE + 4 * e;
              const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]);
              const float2 in23 =
                  make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]);
              out.data.elt[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
                  in01, in23, block_scale_inverse_2x, rbits);
            }
          }
          const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
          const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
          const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
          out.store_to(&out_data_sh[shmem_offset_rowwise]);
        }
      }
    }

    __builtin_assume(thread_amax >= 0);
    thread_amax = fmaxf(thread_amax, block_amax);

    // Wait for shared memory writes to be visible to TMA engine.
    ptx::fence_proxy_async_shared_cta();
    __syncthreads();
    // After syncthreads, writes by all threads are visible to TMA engine.

    // Initiate TMA transfer to copy shared memory to global memory
    if (is_master_thread) {
      const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
      const size_t global_offset_X = block_offset_X;
      const size_t global_offset_Y_t = block_offset_Y_t;
      const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y;

      ptx::cp_async_bulk_tensor_2d_shared_to_global(
          reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y,
          reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));

      if constexpr (RETURN_TRANSPOSE) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_t), global_offset_X_t,
            global_offset_Y_t, reinterpret_cast<uint64_t *>(&out_t_data_sh[buff_offset_out_t]));
      }

      // Create a "bulk async-group" out of the previous bulk copy operation.
      ptx::cp_async_bulk_commit_group();
    }
  }  // end of stages

  // Vectorized store scaling factors through SHMEM
  if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) {
    using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_Y>;
    const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y;
    ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(&out_colwise_scales_sh[scale_idx_sh]);
    const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t;
    // number of scales in Y dimension of this chunk
    const size_t count =
        (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM);
    nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global];
    constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t);
    if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast<uintptr_t>(dst) % vec_bytes == 0)) {
      // Fast path: vectorized store when destination is properly aligned
      scales_vec.store_to(dst);
    } else {
      // Safe path: element-wise store for tails or unaligned destinations
      scales_vec.store_to_elts(dst, 0, count);
    }
  }

  destroy_barriers<STAGES>(mbar, is_master_thread);
#else
  NVTE_DEVICE_ERROR("sm_100 or higher is required.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
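
// The 2D variant below differs from nvfp4_transpose_kernel only in how block_amax is
// obtained: instead of independent 1x16 rowwise and 16x1 columnwise amaxes, both passes
// share a single amax per 16x16 block (BLOCK_DIM x BLOCK_DIM), reduced across a warp
// and staged through the block_amax_matrix shared-memory tile.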
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
          typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM)
    nvfp4_transpose_kernel_2D(const __grid_constant__ CUtensorMap tensor_map_input,
                              const __grid_constant__ CUtensorMap tensor_map_output,
                              const __grid_constant__ CUtensorMap tensor_map_output_t,
                              nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr,
                              const float *noop, const float *const amax_rowwise_ptr,
                              const float *const amax_colwise_ptr, const size_t rows,
                              const size_t cols, const size_t scale_stride,
                              const size_t scale_stride_t, const size_t *rng_state) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
      (!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);

  using IType2 = typename ptx::FPx2<IType>;

  if constexpr (!COMPUTE_ACTIVATIONS) {
    if (noop != nullptr && noop[0] == 1.0f) {
      return;
    }
  }

  const size_t rng_sequence =
      threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
  const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
  const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
  RNG rng(rng_seed, rng_sequence, rng_offset);
  curanddx::uniform_bits dist;
  uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0};
  // Index of the random number. It increments each time it is used and resets to 0 when it reaches 4
  int rnd_idx = 0;

  // NEW: 2D Block-based scaling constants
  constexpr size_t BLOCK_DIM = 16;
  constexpr size_t BLOCKS_PER_TILE_Y = TILE_DIM_Y / BLOCK_DIM;  // 32/16 = 2
  constexpr size_t BLOCKS_PER_TILE_X = TILE_DIM_X / BLOCK_DIM;  // 128/16 = 8
  constexpr size_t ITERATIONS_BLOCK = 2;  // iterations to calculate 2d block amaxes of 1 tile
  constexpr size_t BLOCKS_PER_WARP = BLOCKS_PER_TILE_X / (THREADS_NUM / 32);  // 8 / (128/32) = 2

  constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS;

  const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
  const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
  const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
  const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y;
  const size_t chunk_rows = rows - block_offset_Y;

  const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
  const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X;
  const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
  const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y;

  const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
  const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
  const size_t tid_X_colwise = threadIdx.x;
  const size_t tid_Y_t = tid_X_colwise;

  const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
  const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM;
  const size_t thread_offset_X_colwise = tid_X_colwise;

  const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
  const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
  const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t;
  const size_t scales_offset_X_t = scales_block_offset_X_t;
  const size_t SFs_per_row = cols / SCALE_DIM;

  const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row;
  const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols;

  // Helps resolve bank conflicts in shmem
  const int thread_lane = threadIdx.x % THREADS_PER_WARP;
  const int bank_group = thread_lane / THREADS_PER_BANK;

  constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
  constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
  constexpr size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out =
      DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);

  constexpr size_t in_mem = buff_size_aligned_in;

  constexpr size_t out_mem_rowwise_data = buff_size_aligned_out;
  constexpr size_t out_mem_colwise_data = buff_size_aligned_out;
  constexpr size_t out_mem_rowwise_scales = 0;

  extern __shared__ char dynamic_shmem[];
  uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) Does not guarantee the pointer to be aligned!
  uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
                     ~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));

  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
  IType *in_sh = reinterpret_cast<IType *>(dshmem);
  fp4e2m1x2 *out_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
  fp4e2m1x2 *out_t_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem + out_mem_rowwise_data);
  nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
      dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
  nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
      dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
  IType *cached_act_sh = in_sh;  // in_sh is used as a cache buffer

  constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;

  const bool is_master_thread = (threadIdx.x == 0);

  // Compute the global encoding/decoding scaling factors for all S_dec_b
  const float S_enc_rowwise = (amax_rowwise_ptr == nullptr)
                                  ? 1.0f
                                  : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
  // NOTE: This is to match with how the emulation code was written.
  const float S_dec_rowwise = 1.0 / S_enc_rowwise;

  const float S_enc_colwise = (amax_colwise_ptr == nullptr)
                                  ? S_enc_rowwise
                                  : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
  const float S_dec_colwise = 1.0 / S_enc_colwise;

  const size_t warp_id = threadIdx.x / 32;
  const size_t lane_id = threadIdx.x % 32;

  float thread_amax = 0.0f;
  const size_t block_in_warp = lane_id / BLOCKS_PER_WARP;

  // Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ alignas(8) uint64_t mbar[STAGES];

  __shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1];

  // Helper function for warp reduction
  auto warp_reduce_amax = [](float thread_amax, int block_in_warp) -> float {
#pragma unroll
    for (int delta = 8; delta >= 1; delta /= 2) {
      float other_amax = __shfl_xor_sync(0xffffffff, thread_amax, delta);
      thread_amax = fmaxf(thread_amax, other_amax);
    }
    return thread_amax;
  };

  initialize_barriers<STAGES, THREADS_NUM>(mbar, is_master_thread);

  copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
                    &mbar[0], is_master_thread);

#pragma unroll
  for (size_t stage = 0; stage < STAGES; ++stage) {
    const size_t buff = stage % BUFFS_NUM;
    const size_t next_stage = stage + 1;
    const size_t stage_offset_Y = stage * BUFF_DIM_Y;

    const size_t buff_offset_in = buff * BUFF_IN_SIZE;
    const size_t buff_offset_out = buff * BUFF_OUT_SIZE;
    const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE;

    if (next_stage < STAGES) {
      // Wait for TMA transfer to have finished reading shared memory.
      // I.e. the buffer is ready to be written to
      ptx::cp_async_bulk_wait_group_read<1>();

      const size_t next_buff = next_stage % BUFFS_NUM;
      const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
      const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
      const size_t global_offset_X = block_offset_X;
      const size_t next_buff_offset = next_buff * BUFF_IN_SIZE;

      copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
                        global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
    }

    ptx::fence_proxy_async_shared_cta();

    // Wait for the data to have arrived
    ptx::mbarrier_wait_parity(&mbar[stage], 0);

    float block_amax = 0.0f;

#pragma unroll
    for (size_t block_iter = 0; block_iter < ITERATIONS_BLOCK; ++block_iter) {
      IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
      const size_t block_in_tile_y = block_iter;
      const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM;
      if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
        for (int elem = 0; elem < BLOCK_DIM; elem += 2) {
          const size_t elem_0_row = block_iter * BLOCK_DIM + elem;
          const size_t elem_1_row = elem_0_row + 1;
          const size_t elem_0_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id;
          const size_t elem_1_col = elem_0_col;

          const size_t shmem_offset_0 = buff_offset_in + elem_0_row * BUFF_IN_DIM_X + elem_0_col;
          const size_t shmem_offset_1 = buff_offset_in + elem_1_row * BUFF_IN_DIM_X + elem_1_col;

          IType2 val_2x;
          val_2x.x = in_sh[shmem_offset_0];
          val_2x.y = in_sh[shmem_offset_1];
          ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val_2x);
        }
        thread_amax =
            static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
      } else {
        for (int elem = 0; elem < BLOCK_DIM; ++elem) {
          const size_t elem_row = block_iter * BLOCK_DIM + elem;
          const size_t elem_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id;
          // Bounds checking
          const bool row_out_of_bounds = (block_offset_Y + stage_offset_Y + elem_row >= rows);
          const bool col_out_of_bounds = (block_offset_X + elem_col >= cols);
          if (!row_out_of_bounds && !col_out_of_bounds) {
            const size_t shmem_offset = buff_offset_in + elem_row * BUFF_IN_DIM_X + elem_col;
            float elt = static_cast<float>(in_sh[shmem_offset]);
            if constexpr (COMPUTE_ACTIVATIONS) {
              elt = OP(elt, {});
            }
            if constexpr (!std::is_same_v<IType, float>) {
              elt = static_cast<float>(static_cast<IType>(elt));
            }
            // Cache computed activations
            if constexpr (IS_CACHED_ACT_OP) {
              cached_act_sh[shmem_offset] = static_cast<IType>(elt);
            }
            thread_amax = fmaxf(thread_amax, fabsf(elt));
          }
        }
      }
      // Warp reduction to get block amax
      block_amax = warp_reduce_amax(thread_amax, block_in_warp);
      if (lane_id == 0 || lane_id == 16) {
        block_amax_matrix[block_in_tile_y][block_in_tile_x] = block_amax;
      }
    }
    // sync threads to ensure block_amax_matrix is done storing
    __syncthreads();

    // COLWISE scaling
    if constexpr (RETURN_TRANSPOSE) {
#pragma unroll
      for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) {
        const size_t block_in_tile_y = it;
        const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM;

        const size_t in_thread_offset_Y = 0 + it * SCALE_DIM;
        const size_t in_thread_offset_X = thread_offset_X_colwise;

        const size_t out_t_thread_offset_Y = thread_offset_X_colwise;
        const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET;

        const size_t shmem_offset_base_colwise_in =
            buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X;
        const size_t shmem_offset_base_colwise_out_t =
            buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X;

        block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x];

        float in_compute_colwise[SCALE_DIM];
        IType in_colwise_IType[SCALE_DIM];

        // 1. Load data in
        if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
#pragma unroll
          for (int i = 0; i < SCALE_DIM; ++i) {
            const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
            in_colwise_IType[i] = in_sh[shmem_offset_colwise];
          }
        } else {
          for (int i = 0; i < SCALE_DIM; ++i) {
            const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
            float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
            if constexpr (COMPUTE_ACTIVATIONS) {
              elt = OP(elt, {});
            }
            // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
            if constexpr (!std::is_same_v<IType, float>) {
              elt = static_cast<float>(static_cast<IType>(elt));
            }
            // Cache computed activations to avoid computing them again in the 2nd pass along another dimension
            if constexpr (IS_CACHED_ACT_OP) {
              cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
            }
            in_compute_colwise[i] = elt;
          }
        }

        // 2. Compute E4M3 scaling factor
        const nvfp4_scale_t S_dec_b_fp8 =
            compute_decoding_scaling_factor(block_amax, S_enc_colwise);

        // Store scaling factors through SHMEM
        const size_t scale_idx_sh =
            tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it;
        out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8;

        // Compute "correct" per-block encoding scaling factor
        constexpr float float_max = detail::TypeExtrema<float>::max;
        const float block_scale_inverse = fminf(
            1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_colwise), float_max);  // S_enc_b_fp8
        const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};

        // 3. Scale elements
        fp4e2m1x4 regs[SCALE_DIM / 4];
#pragma unroll
        for (int e = 0; e < SCALE_DIM / 4; ++e) {
          const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
          if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
            const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_colwise_IType[4 * e]);
            regs[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(elts, block_scale_inverse_2x,
                                                                      rbits);
          } else {
            const float2 in01 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e]);
            const float2 in23 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e + 2]);
            regs[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
                in01, in23, block_scale_inverse_2x, rbits);
          }
        }

        const int group = thread_lane / 16;
        uint32_t val[2];
        uint32_t *regs_4x = reinterpret_cast<uint32_t *>(regs);
        // Helps reduce bank conflicts
        switch (group) {
          case 0:
            val[0] = regs_4x[0];
            val[1] = regs_4x[1];
            break;
          case 1:
            val[0] = regs_4x[1];
            val[1] = regs_4x[0];
            break;
        }
        uint32_t *out_t_data_sh_as_uint32_t =
            reinterpret_cast<uint32_t *>(&out_t_data_sh[shmem_offset_base_colwise_out_t]);
        out_t_data_sh_as_uint32_t[group] = val[0];            // idx1 = (group + 0) % 2;
        out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1];  // idx2 = (group + 1) % 2;
      }
    }

    // ROWWISE scaling
    {
      const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
      for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) {
        const size_t block_in_tile_y = it;
        const size_t block_in_tile_x = tid_X_rowwise;

        const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;

        const size_t shmem_offset_base_rowwise_in =
            buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
        const size_t shmem_offset_base_rowwise_out =
            buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;

        block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x];

        float in_compute_rowwise[SCALE_DIM];
        Vec<IType, PACK_SIZE> in_cached[WAVES];

        // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
        Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];

        // 1. Read/Compute elements. Find NVFP4-block AMAX
        if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
          IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
            const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
            // Load elements
            in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
          }
        } else if constexpr (IS_CACHED_ACT_OP) {
          // ensures that all writes to cache made in the section above are visible to all threads
          __syncthreads();
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
            const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
            // Load cached elements
            in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
          }
        } else {
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
            const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
            Vec<IType, PACK_SIZE> in;
            Vec<IType, PACK_SIZE> act_in;
            in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
            for (int e = 0; e < PACK_SIZE; ++e) {
              const size_t j = w * PACK_SIZE + e;
              // Compute element
              float elt = static_cast<float>(in.data.elt[e]);
              if constexpr (COMPUTE_ACTIVATIONS) {
                elt = OP(elt, {});
              }
              // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
              if constexpr (!std::is_same_v<IType, float>) {
                elt = static_cast<float>(static_cast<IType>(elt));
              }
              in_compute_rowwise[j] = elt;
            }
          }
        }

        // 2. Compute E4M3 scaling factor
        const nvfp4_scale_t S_dec_b_fp8 =
            compute_decoding_scaling_factor(block_amax, S_enc_rowwise);

        // Check boundaries
        const size_t scales_offset_Y =
            scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE;
        const size_t scales_offset_X = scales_offset_X_rowwise;
        const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X;

        // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows;
        const bool rowwise_scale_is_within_bounds_Y =
            (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows;
        if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) {
          scales_ptr[scale_idx_global] = S_dec_b_fp8;
        }

        // Compute "correct" per-block encoding scaling factor
        constexpr float float_max = detail::TypeExtrema<float>::max;
        const float block_scale_inverse = fminf(
            1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_rowwise), float_max);  // S_enc_b_fp8
        const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};

        // 3. Scale elements
#pragma unroll
        for (int w = 0; w < WAVES; ++w) {
          Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
          for (int e = 0; e < PACK_SIZE / 4; ++e) {
            const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
            IType2 in01;
            IType2 in23;
            if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
              const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_IType[w].data.elt[2 * e]);
              out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
                  elts, block_scale_inverse_2x, rbits);
            } else if constexpr (IS_CACHED_ACT_OP) {
              const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_cached[w].data.elt[4 * e]);
              out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
                  elts, block_scale_inverse_2x, rbits);
            } else {
              const int j = w * PACK_SIZE + 4 * e;
              const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]);
              const float2 in23 =
                  make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]);
              out.data.elt[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
                  in01, in23, block_scale_inverse_2x, rbits);
            }
          }
          const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
          const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
          const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
          out.store_to(&out_data_sh[shmem_offset_rowwise]);
        }
      }
    }

    __builtin_assume(thread_amax >= 0);
    thread_amax = fmaxf(thread_amax, block_amax);

    // Wait for shared memory writes to be visible to TMA engine.
    ptx::fence_proxy_async_shared_cta();
    __syncthreads();
    // After syncthreads, writes by all threads are visible to TMA engine.

    // Initiate TMA transfer to copy shared memory to global memory
    if (is_master_thread) {
      const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
      const size_t global_offset_X = block_offset_X;
      const size_t global_offset_Y_t = block_offset_Y_t;
      const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y;

      ptx::cp_async_bulk_tensor_2d_shared_to_global(
          reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y,
          reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));

      if constexpr (RETURN_TRANSPOSE) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_t), global_offset_X_t,
            global_offset_Y_t, reinterpret_cast<uint64_t *>(&out_t_data_sh[buff_offset_out_t]));
      }

      // Create a "bulk async-group" out of the previous bulk copy operation.
      ptx::cp_async_bulk_commit_group();
    }
  }  // end of stages

  // Vectorized store scaling factors through SHMEM
  if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) {
    using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_Y>;
    const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y;
    ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(&out_colwise_scales_sh[scale_idx_sh]);
    const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t;
    // number of scales in Y dimension of this chunk
    const size_t count =
        (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM);
    nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global];
    constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t);
    if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast<uintptr_t>(dst) % vec_bytes == 0)) {
      // Fast path: vectorized store when destination is properly aligned
      scales_vec.store_to(dst);
    } else {
      // Safe path: element-wise store for tails or unaligned destinations
      scales_vec.store_to_elts(dst, 0, count);
    }
  }

  destroy_barriers<STAGES>(mbar, is_master_thread);
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}  // namespace nvfp4_transpose
#endif  // CUDA_VERSION > 12080

// Compile-time flag to choose kernel variant
#ifndef USE_2D_NVFP4_KERNEL
#define USE_2D_NVFP4_KERNEL 0
#endif
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
          bool use_2d_quantization>
void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
                              const QuantizationConfig *quant_config, cudaStream_t stream) {
#if CUDA_VERSION > 12080
  bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;

  // If the transposed output is allocated, return the transposed data. Otherwise, it's not
  // necessary to return the transposed data.
  // TODO(Frank): Is there a better way to do this?
  bool return_transpose = output->has_columnwise_data();

  using namespace nvfp4_transpose;
  using namespace ptx;

  checkCuDriverContext(stream);
  CheckNoopTensor(*noop, "cast_noop");
  CheckInputTensor(input, "input");
  CheckOutputTensor(*output, "output", false);

  NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
  NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated.");
  NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
  NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");

  if (return_transpose) {
    NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated.");
    NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype),
               "Transposed output must have FP4 type.");
    NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
               "Transposed scaling tensor must be allocated");
  }

  const size_t rows = input.flat_first_dim();
  const size_t cols = input.flat_last_dim();

  NVTE_CHECK(rows % 32 == 0, "Number of tensor rows must be a multiple of 32");  // 16B alignment for TMA
  NVTE_CHECK(cols % 32 == 0, "Number of tensor cols must be a multiple of 32");  // 16B alignment for TMA

  const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
  const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
  const dim3 grid(blocks_X, blocks_Y);
  const size_t block_size = THREADS_NUM;

  const size_t scale_stride = output->scale_inv.shape[1];
  const size_t scale_stride_transpose =
      return_transpose ? output->columnwise_scale_inv.shape[1] : 0;

  nvfp4_scale_t *const scales_ptr = reinterpret_cast<nvfp4_scale_t *>(output->scale_inv.dptr);
  nvfp4_scale_t *const scales_transpose_ptr =
      reinterpret_cast<nvfp4_scale_t *>(output->columnwise_scale_inv.dptr);

  const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
  const float *const amax_rowwise_ptr = reinterpret_cast<const float *>(output->amax.dptr);
  const float *const amax_colwise_ptr =
      reinterpret_cast<const float *>(output->columnwise_amax.dptr);

  const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr;
  const size_t *rng_state = nullptr;
  if (rng_state_tensor != nullptr) {
    Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
    NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
               "RNG state should contain 2 64-bit values.");
    NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
               "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
    rng_state = reinterpret_cast<const size_t *>(rng_state_te_tensor.data.dptr);
  }

  using IType = bf16;

  alignas(64) CUtensorMap tensor_map_input{};
  alignas(64) CUtensorMap tensor_map_output{};
  alignas(64) CUtensorMap tensor_map_output_transpose{};

  create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
                       sizeof(IType) * 8);
  create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols,
                       0, 4);

  if (return_transpose) {
    create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows,
                         BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4);
  }

  constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
  constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
  constexpr size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out =
      DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t);

  constexpr size_t in_mem = buff_size_aligned_in;

  constexpr size_t out_data_mem = buff_size_aligned_out;
  constexpr size_t out_data_transpose_mem = buff_size_aligned_out;
  constexpr size_t out_scales_transpose_mem = buff_size_scales;

  constexpr size_t out_mem = out_data_mem + out_data_transpose_mem;

  constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT;
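
  // Shared memory budget implied by the constants above (a worked example, assuming
  // TMA_SHMEM_ALIGNMENT is 128 bytes; the exact value comes from the common headers):
  //   buff_elems_total         = 2 * 32 * 128           = 8192 elements
  //   in_mem (BF16 input)      = 8192 * 2 B             = 16384 B
  //   out_data_mem (FP4)       = 8192 * 4 / 8 B         = 4096 B (same for the transpose buffer)
  //   out_scales_transpose_mem = 128 * 128 / 16 * 1 B   = 1024 B
  //   dshmem_size              = 16384 + 4096 + 4096 + 1024 + 128 = 25728 B per block.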
TRANSFORMER_ENGINE_SWITCH_CONDITION
(
use_stochastic_rounding
,
USE_STOCHASTIC_ROUNDING
,
TRANSFORMER_ENGINE_SWITCH_CONDITION
(
return_transpose
,
RETURN_TRANSPOSE
,
{
auto
kernel
=
nvfp4_transpose_kernel
<
COMPUTE_ACTIVATIONS
,
ParamOP
,
OP
,
IType
,
USE_STOCHASTIC_ROUNDING
,
RETURN_TRANSPOSE
>
;
if
constexpr
(
use_2d_quantization
)
{
kernel
=
nvfp4_transpose_kernel_2D
<
COMPUTE_ACTIVATIONS
,
ParamOP
,
OP
,
IType
,
USE_STOCHASTIC_ROUNDING
,
RETURN_TRANSPOSE
>
;
}
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
dshmem_size
);
kernel
<<<
grid
,
block_size
,
dshmem_size
,
stream
>>>
(
tensor_map_input
,
tensor_map_output
,
tensor_map_output_transpose
,
scales_ptr
,
scales_transpose_ptr
,
noop_ptr
,
amax_rowwise_ptr
,
amax_colwise_ptr
,
rows
,
cols
,
scale_stride
,
scale_stride_transpose
,
rng_state
);
}););
#else
  NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif  // CUDA_VERSION >= 12080
}

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
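The launch above requests dshmem_size bytes of dynamic shared memory, which can exceed the default 48 KB per-block limit, so the kernel is first opted in via cudaFuncSetAttribute. A minimal standalone sketch of that pattern follows; the kernel and sizes are invented for illustration and are not TE code:

// Illustrative only: opting a kernel into more than 48 KB of dynamic shared memory.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void demo_kernel(float *out) {
  extern __shared__ char smem[];  // sized by the launch configuration
  smem[threadIdx.x] = 0;
  __syncthreads();
  if (threadIdx.x == 0) out[blockIdx.x] = static_cast<float>(smem[0]);
}

int main() {
  const int dshmem_size = 64 * 1024;  // above the 48 KB default, so opt-in is required
  // Without this attribute, a launch requesting this much dynamic shared
  // memory typically fails with cudaErrorInvalidValue.
  cudaFuncSetAttribute(demo_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
  float *out = nullptr;
  cudaMalloc(&out, sizeof(float));
  demo_kernel<<<1, 256, dshmem_size>>>(out);
  printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(out);
  return 0;
}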
transformer_engine/common/util/ptx.cuh
@@ -14,6 +14,10 @@
#include <cuda.h>
#include <cuda_runtime.h>

#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif  // CUDA_VERSION >= 12080

namespace transformer_engine {
namespace ptx {
@@ -125,9 +129,13 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
  return __int_as_float(biased_exp << FP32_MANTISSA_BITS);
}

#define CUDA_ARCH_HAS_FEATURE_SM10X_ALL \
  ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
   (__CUDA_ARCH_HAS_FEATURE__(SM103_ALL)))

__device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  uint16_t out;
  asm volatile(
      "{\n"
@@ -230,18 +238,86 @@ struct alignas(2 * sizeof(T)) FPx2 {
  T y;
};

template <typename T>
struct FPx4 {
  T x1;
  T x2;
  T x3;
  T x4;
};

template <typename T>
struct Type2x {};

template <>
struct Type2x<float> {
  using type = float2;
};

template <>
struct Type2x<bf16> {
  using type = __nv_bfloat162;
};

template <>
struct Type2x<fp16> {
  using type = __half2;
};

using floatx2 = FPx2<float>;
using bf16x2 = FPx2<bf16>;
using fp16x2 = FPx2<fp16>;
using fp8e4m3x2 = FPx2<fp8e4m3>;
using fp8e5m2x2 = FPx2<fp8e5m2>;

using floatx4 = FPx4<float>;
using bf16x4 = FPx4<bf16>;
using fp16x4 = FPx4<fp16>;
using fp8e4m3x4 = FPx4<fp8e4m3>;
using fp8e5m2x4 = FPx4<fp8e5m2>;

static_assert(sizeof(floatx2) == 8);
static_assert(sizeof(bf16x2) == 4);
static_assert(sizeof(fp16x2) == 4);
static_assert(sizeof(fp8e4m3x2) == 2);
static_assert(sizeof(fp8e5m2x2) == 2);

#if CUDA_VERSION >= 12080
using fp4e2m1 = __nv_fp4_e2m1;
using fp4e2m1x2 = __nv_fp4x2_e2m1;
using fp4e2m1x4 = __nv_fp4x4_e2m1;

static_assert(sizeof(fp4e2m1x2) == 1);
static_assert(sizeof(fp4e2m1x4) == 2);
#endif  // CUDA_VERSION >= 12080
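As a side note, a small sketch of how such paired wrappers are typically consumed: the alignas on FPx2 makes a pair reinterpretable as the matching native vector type from Type2x, so one wider load/store moves both lanes. The names below are local to the example and not part of this header:

#include <cuda_bf16.h>

namespace demo {
template <typename T>
struct alignas(2 * sizeof(T)) Pair {
  T x, y;
};

// Store two bf16 lanes with a single 32-bit transaction instead of two
// 16-bit ones; the alignas above makes the reinterpret_cast safe in practice.
__device__ void store_pair(__nv_bfloat16 *dst, Pair<__nv_bfloat16> v) {
  *reinterpret_cast<__nv_bfloat162 *>(dst) = *reinterpret_cast<__nv_bfloat162 *>(&v);
}
}  // namespace demo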
// cvt.rn.satfinite.e2m1x2.f32 d, a, b; // Convert two FP32 values to two packed e2m1
// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 introduced in PTX ISA version 8.6.
// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 is supported on the following architectures:
// sm_100a
// sm_101a
// sm_120a
// When converting to .e2m1x2 data formats, the destination operand d has .b8 type.
// When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format,
// and the converted values are packed in the destination operand d such that the value
// converted from input a is stored in the upper 4 bits of d and the value converted
// from input b is stored in the lower 4 bits of d.
// SIMD like "Fused" cast + multiplication (x4)
#if CUDA_VERSION >= 12080
template <typename Tx2>
__device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, const Tx2 &in23,
                                           const float scale) {
  const float x0 = static_cast<float>(in01.x) * scale;
  const float x1 = static_cast<float>(in01.y) * scale;
  const float x2 = static_cast<float>(in23.x) * scale;
  const float x3 = static_cast<float>(in23.y) * scale;
  out = fp4e2m1x4(make_float4(x0, x1, x2, x3));
}
#endif  // CUDA_VERSION >= 12080
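A hedged usage sketch of mul_cvt_4x's conversion step, relying only on the cuda_fp4.h types aliased above (the helper name is invented). E2M1 stores one sign, two exponent, and one mantissa bit per value, so each nibble can only hold {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}, with out-of-range inputs saturating:

#if CUDA_VERSION >= 12080
__device__ unsigned short quantize_4x_to_fp4(float4 v, float scale) {
  // Scale, then convert: four FP32 values become four e2m1 nibbles packed
  // into two bytes (see the PTX note above for the nibble ordering).
  fp4e2m1x4 packed(make_float4(v.x * scale, v.y * scale, v.z * scale, v.w * scale));
  return reinterpret_cast<unsigned short &>(packed);
}
#endif  // CUDA_VERSION >= 12080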
// SIMD like "Fused" cast + multiplication (x2)
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
                                           const floatx2 &scale) {

@@ -377,7 +453,7 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const fp16x2 &p2) {
               "r"(reinterpret_cast<const uint32_t &>(p2)));
}

#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)

}  // namespace ptx
transformer_engine/common/util/pybind_helper.h
@@ -27,6 +27,7 @@
.value("kBFloat16", transformer_engine::DType::kBFloat16) \
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
.value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1) \
.value("kInt8", transformer_engine::DType::kInt8); \
pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
@@ -41,6 +42,10 @@
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \
pybind11::enum_<NVTE_Softmax_Type>(m, "NVTE_Softmax_Type", pybind11::module_local()) \
.value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) \
.value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX) \
.value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX); \
pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local()) \
.value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \
.value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \
transformer_engine/common/util/vectorized_pointwise.h
@@ -11,7 +11,7 @@
#include "../common.h"
#include "../utils.cuh"
#include "math.h"

namespace transformer_engine {

/* \brief Helper class that enables storing multiple values of type DType
@@ -345,7 +345,7 @@ template <int nvec, typename Param, fp32 (*OP)(const fp32, const Param &), typename InputType,
          typename OutputType>
void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output,
                                   const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N,
                                   const Param &params, cudaStream_t stream) {
  if (N != 0) {
    auto align = CheckAlignment(N, nvec, input, output);
@@ -379,7 +379,7 @@ template <int nvec, typename Param, fp32 (*OP)(fp32, const Param &), typename InputType,
          typename InputTypeGrad, typename OutputType>
void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input,
                                       OutputType *output, const fp32 *scale, fp32 *amax,
                                       fp32 *scale_inv, const size_t N, const Param &params,
                                       cudaStream_t stream) {
  if (N != 0) {
    auto align = CheckAlignment(N, nvec, input, grad, output);
@@ -438,7 +438,13 @@ __launch_bounds__(unary_kernel_threads) __global__
#pragma unroll
    for (int i = 0; i < nvec; ++i) {
      const ComputeType val = static_cast<ComputeType>(loader0.separate()[i]);
      ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
      if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
        // Clamp the gated value and add 1 at the end
        ComputeType limit = p.limit;
        val2 = std::min(std::max(-limit, val2), limit) + 1;
      }
      ComputeType temp = static_cast<ComputeType>(Activation(val, p) * val2);
      if (requires_amax) {
        __builtin_assume(max >= 0);
@@ -539,10 +545,18 @@ __launch_bounds__(unary_kernel_threads) __global__
    for (int i = 0; i < nvec; ++i) {
      const ComputeType grad_val = static_cast<ComputeType>(grad_loader.separate()[i]);
      const ComputeType gelu_in = static_cast<ComputeType>(input_loader0.separate()[i]);
      ComputeType gate_in = static_cast<ComputeType>(input_loader1.separate()[i]);
      bool dgate_in = true;
      if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
        // In case of GPT OSS, clamp the activation and gate values
        const ComputeType limit = p.limit;
        dgate_in = gate_in <= limit && gate_in >= -limit;
        // Derivative of clamp
        gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f;
      }
      ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;
      ComputeType after_dgate = dgate_in ? grad_val * Activation(gelu_in, p) : 0.0f;
      if (requires_amax) {
        __builtin_assume(max >= 0);
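A standalone illustration of the clamp rule used in the two hunks above (plain C++, names local to this sketch): the forward pass gates with clamp(x, -limit, limit) + 1, and since the derivative of clamp is 1 inside the interval and 0 outside, the backward pass zeroes the gate-gradient path whenever the input was clipped.

#include <algorithm>

// Forward: matches val2 = std::min(std::max(-limit, val2), limit) + 1.
float clamped_gate(float x, float limit) {
  return std::min(std::max(-limit, x), limit) + 1.0f;
}

// Backward: the indicator on [-limit, limit] is exactly the dgate_in flag above.
float clamped_gate_grad(float x, float limit) {
  return (x >= -limit && x <= limit) ? 1.0f : 0.0f;
}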
transformer_engine/common/utils.cuh
////////////////////////////////////////////////////////////////////////////////////////////////////
// Device-side error
#define NVTE_DEVICE_ERROR(message) \
do { \
printf("%s:%d in function %s (thread (%d,%d,%d), block (%d,%d,%d)): %s\n", __FILE__, __LINE__, \
__func__, threadIdx.x, threadIdx.y, threadIdx.z, blockIdx.x, blockIdx.y, blockIdx.z, \
(message)); \
assert(0); \
} while (false)
// Device-side error on thread 0
#define NVTE_DEVICE_THREAD0_ERROR(message) \
do { \
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && \
threadIdx.y == 0 && threadIdx.z == 0) { \
NVTE_DEVICE_ERROR(message); \
} \
} while (false)
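A hypothetical kernel showing the intended use of these macros; the thread-0 variant keeps a grid-wide shape failure from printing once per thread:

__global__ void check_rows(int rows) {
  if (rows % 32 != 0) {
    // Prints file/line plus thread/block coordinates once, then asserts.
    NVTE_DEVICE_THREAD0_ERROR("rows must be a multiple of 32");
  }
}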
////////////////////////////////////////////////////////////////////////////////////////////////////
#if !defined(USE_HIPBLASLT) && !defined(__HIPCC_RTC__)
inline __device__ float2 operator+(const float2 &a, const float2 &b) {  // NOLINT(*)
  return {a.x + b.x, a.y + b.y};
transformer_engine/debug/features/fake_quant.py
@@ -19,7 +19,7 @@ from transformer_engine.common.recipe import Format
from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.quantization import _default_sf_compute


def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None):
transformer_engine/debug/features/log_fp8_tensor_stats.py
@@ -290,10 +290,16 @@ class LogFp8TensorStats(BaseLogTensorStats):
        for stat in config["stats"]:
            self.check_if_stat_is_supported(stat, recipe_name)

        start_step = config.get("start_step", None)
        end_step = config.get("end_step", None)
        start_end_list = config.get("start_end_list", None)
        if start_end_list is not None:
            start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list)
        options = (
            start_step,
            end_step,
            start_end_list,
            "fp8",
        )
transformer_engine/debug/features/log_tensor_stats.py
@@ -15,8 +15,8 @@ import nvdlfw_inspect.api as debug_api
from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage
from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params
@@ -123,17 +123,23 @@ class LogTensorStats(BaseLogTensorStats):
        """API call used to collect the data about the tensor before process_tensor()/quantization."""

        assert (
            type(tensor) not in [Float8Tensor, Float8TensorStorage, MXFP8Tensor, MXFP8TensorStorage]
            and tensor.dtype != torch.uint8
        ), (
            f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be in high precision when using"
            " log_tensor_stats. Use log_fp8_tensor_stats for FP8 tensors."
        )

        start_step = config.get("start_step", None)
        end_step = config.get("end_step", None)
        start_end_list = config.get("start_end_list", None)
        if start_end_list is not None:
            start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list)
        options = (
            start_step,
            end_step,
            start_end_list,
        )

        skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
transformer_engine/debug/features/utils/stats_buffer.py
@@ -172,11 +172,19 @@ class StatsBuffers:
        if self.at_least_one_layer_fed:
            return True
        iteration = TEDebugState.get_iteration()
        layers_to_remove = []
        for layer_name, next_iter in self.layers_to_next_iter.items():
            # When next_iter is None the feature will no longer run.
            if next_iter is None:
                layers_to_remove.append(layer_name)
                continue
            # Note that a layer may not run for many iterations; in that case we
            # synchronize every step until we get any information from it.
            if iteration >= next_iter:
                return True
        for layer_name in layers_to_remove:
            self.layers_to_next_iter.pop(layer_name, None)
        return False

    def reset(self):
transformer_engine/debug/pytorch/debug_quantization.py
@@ -18,7 +18,7 @@ from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.quantized_tensor import (
    QuantizedTensor,
    Quantizer,
    QuantizedTensorStorage,
    prepare_for_saving,
    restore_from_saved,
)

@@ -557,7 +557,7 @@ class DebugQuantizer(Quantizer):
        self._update_parent_quantizer_usage()


class DebugQuantizedTensor(QuantizedTensorStorage):
    """
    Class containing quantized tensors after debug. Depending on configuration
    it can contain one or two different objects. These objects can be accessed by the method
transformer_engine/jax/__init__.py
@@ -34,7 +34,7 @@ load_framework_extension("jax")
from . import flax
from . import quantize

from .quantize import autocast, fp8_autocast, update_collections
from .quantize import NVTE_FP8_COLLECTION_NAME
from .sharding import MeshResource

@@ -45,9 +45,9 @@ from ..common.utils import DeprecatedEnum
__all__ = [
    "NVTE_FP8_COLLECTION_NAME",
    "autocast",
    "fp8_autocast",
    "update_collections",
    "MeshResource",
    "flax",
    "quantize",