OpenDAS / TransformerEngine, commit 063ef88d

Merge nv main up to v2.10.0.dev0

Authored Dec 03, 2025 by wenjh
Signed-off-by: wenjh <wenjh@sugon.com>
Parents: 91670b05, 5624dbb4
Changes: 298 changed files in total; the 20 files below are shown, with 3534 additions and 139 deletions (+3534 / -139).
transformer_engine/common/transpose/cast_transpose.h                               +9    -0
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu         +6    -0
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu         +6    -0
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu     +842  -0
transformer_engine/common/util/cast_gated_kernels.cuh                              +94   -55
transformer_engine/common/util/cast_kernels.cuh                                    +761  -46
transformer_engine/common/util/dequantize_kernels.cuh                              +96   -10
transformer_engine/common/util/logging.h                                           +10   -0
transformer_engine/common/util/math.h                                              +35   -4
transformer_engine/common/util/nvfp4_transpose.cuh                                 +1516 -0
transformer_engine/common/util/ptx.cuh                                             +79   -3
transformer_engine/common/util/pybind_helper.h                                     +5    -0
transformer_engine/common/util/vectorized_pointwise.h                              +20   -6
transformer_engine/common/utils.cuh                                                +20   -0
transformer_engine/debug/features/fake_quant.py                                    +1    -1
transformer_engine/debug/features/log_fp8_tensor_stats.py                          +9    -3
transformer_engine/debug/features/log_tensor_stats.py                              +12   -6
transformer_engine/debug/features/utils/stats_buffer.py                            +9    -1
transformer_engine/debug/pytorch/debug_quantization.py                             +2    -2
transformer_engine/jax/__init__.py                                                 +2    -2
transformer_engine/common/transpose/cast_transpose.h

@@ -8,6 +8,7 @@
 #define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_

 #include "../common.h"
+#include "transformer_engine/transformer_engine.h"

 namespace transformer_engine::detail {
@@ -62,6 +63,14 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor
                                          const bool pow_2_scale, const SimpleTensor &noop_tensor,
                                          cudaStream_t stream);

+void quantize_transpose_vector_blockwise_fp4(
+    const SimpleTensor &input, const SimpleTensor &global_amax, SimpleTensor &scale_inv,
+    SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon,
+    const bool return_identity, const bool return_transpose, const bool pow2_scale,
+    const bool swizzled_scale, const bool use_stochastic_rounding,
+    const NVTETensor rng_state_tensor, const bool use_2d_quantization,
+    const SimpleTensor &noop_tensor, cudaStream_t stream);
+
 }  // namespace transformer_engine::detail

 #endif  // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu

@@ -18,6 +18,7 @@
 #include "common/common.h"
 #include "common/recipe/recipe_common.cuh"
+#include "common/util/cuda_runtime.h"
 #include "common/util/ptx.cuh"
 #include "common/utils.cuh"
@@ -901,6 +902,11 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
   NVTE_API_CALL(quantize_transpose_square_blockwise);
   checkCuDriverContext(stream);

+  if (transformer_engine::cuda::sm_arch() >= 100) {
+    NVTE_CHECK(pow_2_scale,
+               "On Blackwell and newer, the FP8 block scaling recipe is emulated ",
+               "with MXFP8, which requires using power of two scaling factors.");
+  }
+
   NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
   const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
   size_t num_rows = 1;
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu

@@ -24,6 +24,7 @@
 #include "common/common.h"
 #include "common/recipe/recipe_common.cuh"
 #include "common/transpose/cast_transpose.h"
+#include "common/util/cuda_runtime.h"
 #include "common/utils.cuh"

 namespace transformer_engine {
@@ -1480,6 +1481,11 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
                                          cudaStream_t stream) {
   NVTE_API_CALL(quantize_transpose_vector_blockwise);

+  if (transformer_engine::cuda::sm_arch() >= 100) {
+    NVTE_CHECK(pow2_scale,
+               "On Blackwell and newer, the FP8 block scaling recipe is emulated ",
+               "with MXFP8, which requires using power of two scaling factors.");
+  }
+
   const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
   size_t num_elements = row_length;
   size_t num_rows = 1;
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu  (new file, mode 100644)
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cfloat>
#include <cuda/barrier>
#include <utility>
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "curanddx.hpp"
namespace transformer_engine {

#if CUDA_VERSION >= 12080

namespace quantize_transpose_nvfp4 {

namespace {

using std::int32_t;
using std::uint32_t;
using std::uint8_t;
using transformer_engine::detail::TypeExtrema;

// Define a cuRANDDx descriptor
// Note: curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified,
// it defaults to 10.
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used to do
// some internal checks, e.g., whether shared memory (if needed) is enough for the described
// problem; usually not applicable here.
// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
                     curanddx::SM<800>() + curanddx::Thread());
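
// Illustrative sketch (hypothetical, added for this write-up; not part of the upstream file):
// how the RNG descriptor above is typically consumed. Each thread constructs the Philox
// generator from a (seed, subsequence, offset) triple and draws four 32-bit words at a time,
// exactly as the quantize/transpose kernel below does. Kernel name and output layout are made up.
__global__ void rng_usage_sketch(const size_t *rng_state, uint32_t *out) {
  const size_t seed = rng_state != nullptr ? rng_state[0] : 0;
  const size_t offset = rng_state != nullptr ? rng_state[1] : 0;
  const size_t sequence = threadIdx.x + blockIdx.x * blockDim.x;  // one subsequence per thread
  RNG rng(seed, sequence, offset);
  curanddx::uniform_bits dist;
  const uint4 rbits = dist.generate4(rng);  // 4 x 32 random bits
  out[sequence] = rbits.x;
}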
// clang-format off
/*
Step 1: Load input to shared memory
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 8 times
* What each thread does in each loop:
* 8 elements are read from the input at a time
* 2 elements are written to the shared memory at a time, for a total of 4 times
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 1 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 7 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 8 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 2: Cast and store to output_c
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 4 times
* What each thread does in each loop:
* 2 elements are read from the shared memory at a time, for a total of 8 times
* Every 8 consecutive threads do reduction and calculate the amax of each row
* 16 elements are quantized and written to output_c at a time
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 |
| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 |
| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 1 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 7 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 4 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 3: Transpose, cast and store to output_t
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 2 times
* What each thread does in each loop:
* 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times
* Every 8 consecutive threads do reduction and calculate the amax of each column
* 16 elements are quantized and written to output_t at a time, for a total of 2 times
+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+
| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | |
| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | |
| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | |
| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 |
| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | |
| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | |
| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | |
| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | |
+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+
*/
// clang-format on
constexpr int kThreadsPerWarp = 32;

// for fp4, we use uint8_t to store 2 fp4 numbers
constexpr int kNFP4PerContainer = 2;

// Hyperparameters for performance tuning
constexpr int kTileDim = 128;
// constexpr int kScaleDim = 32;
constexpr int kNVecIn = 8;             // The number of elements each LDG touches
constexpr int kNVecOut = 16;           // The number of elements each STG touches
constexpr int kNVecSMem = 2;           // The number of elements each LDS/STS touches
constexpr int kThreadsPerBlock = 256;  // Thread block size, 8 warps in total

// Auto-calculated constants (do not modify directly)
static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem");
static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem");
constexpr int kSMemRow = kTileDim;
constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1;
constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem;
constexpr int kNumThreadsLoad = kTileDim / kNVecIn;    // 16
constexpr int kNumThreadsStore = kTileDim / kNVecOut;  // 8
// constexpr int kNumThreadsReduce = kScaleDim / kNVecOut;
static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");

// for 2D block scaling, we need to reduce amax in warp
static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = {
    0x01010101, 0x02020202, 0x04040404, 0x08080808,
    0x10101010, 0x20202020, 0x40404040, 0x80808080};
// max for every group_size elements in warp
template <int group_size, int shfl_down_stride>
__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) {
  for (int offset = group_size / 2; offset > 0; offset /= 2) {
    val = max(val, __shfl_down_sync(groupMask, val, offset * shfl_down_stride));
  }
  return val;
}
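
// Illustrative sketch (hypothetical, not part of the upstream file): how groupMax and the masks
// above cooperate. With kNumThreadsStore == 8 threads per row, lanes {t, t+8, t+16, t+24} of a
// warp hold the per-row amax of four consecutive rows; groupMax<4, 8> folds them together, and
// WARP_REDUCE_AMAX_GROUP_MASKS[t] (e.g. 0x01010101 for t == 0, i.e. lanes 0, 8, 16, 24) names
// exactly the lanes that participate in that shuffle.
__device__ __forceinline__ float reduce_four_rows_sketch(float my_amax) {
  const int lane = threadIdx.x % kThreadsPerWarp;
  const unsigned int mask = WARP_REDUCE_AMAX_GROUP_MASKS[lane % 8];
  // After the reduction, the lowest lane of each group (lanes 0..7) holds the 4-row maximum.
  return groupMax<4, 8>(my_amax, mask);
}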
template <typename ScaleType>
__device__ __forceinline__ ScaleType ComputeDecodeScaleFP4(const float amax,
                                                           const float global_encode_scale) {
  float decode_scale = amax / TypeExtrema<fp4e2m1>::max;
  decode_scale = decode_scale * global_encode_scale;
  decode_scale = fminf(decode_scale, TypeExtrema<float>::max);
  return static_cast<ScaleType>(decode_scale);
}

template <typename ScaleType>
__device__ __forceinline__ float ComputeEncodeScaleFP4(ScaleType decode_scale,
                                                       const float global_decode_scale) {
  return fminf(1.0f / (static_cast<float>(decode_scale) * global_decode_scale),
               TypeExtrema<float>::max);
}

template <typename IType, typename ScaleType>
__device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scale) {
  return static_cast<float>(input) * encode_scale;
}
__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) {
  constexpr float fp8_max = TypeExtrema<fp8e4m3>::max;
  constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;
  float global_encode_scale = fp8_max * fp4_max / global_amax;
  // If scale is infinity, return max value of float32
  global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
  // If global amax is 0 or infinity, return 1
  if (global_amax == 0.f || global_encode_scale == 0.f) {
    return 1.f;
  }
  return global_encode_scale;
}
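
// Worked example (hypothetical, host-style; not part of the upstream file) of the two-level
// NVFP4 scaling implemented by the helpers above, with FP8 E4M3 max = 448 and FP4 E2M1 max = 6.
// Clamping and the zero/infinity special cases are omitted for brevity.
inline float nvfp4_scaling_sketch(float global_amax, float block_amax) {
  const float fp8_max = 448.0f;
  const float fp4_max = 6.0f;
  const float global_encode_scale = fp8_max * fp4_max / global_amax;        // ComputeGlobalEncodeScaleFP4
  const float global_decode_scale = 1.0f / global_encode_scale;
  const float decode_scale = (block_amax / fp4_max) * global_encode_scale;  // ComputeDecodeScaleFP4
  const float encode_scale = 1.0f / (decode_scale * global_decode_scale);   // ComputeEncodeScaleFP4
  // An input equal to block_amax is mapped to roughly fp4_max (= 6) before FP4 rounding.
  return block_amax * encode_scale;
}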
__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) {
  if (rnd_idx == 4) {
    rnd_idx = 0;
    curanddx::uniform_bits dist;
    random_uint4 = dist.generate4(rng);
  }
  // Treat uint4 as an array of 4x uint32_t elements for indexing
  const uint32_t *const rbits_arr = reinterpret_cast<uint32_t *>(&random_uint4);
  const uint32_t rbits = rbits_arr[rnd_idx++];
  return rbits;
}
template <class ScaleType>
__device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, size_t col_idx,
                                                               uint32_t col_length) {
  // This function takes in indices from the scale factor matrix and returns an offset in the
  // swizzled format. row_idx, col_idx are original indices from the scale factor matrix (unswizzled
  // index). col_length is the column length of the scale factor matrix. tile_scales_inv is the
  // pointer to the scale factor matrix.
  // https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts
  // For any scale factor matrix, it's a 512B base block. Each base block consists of 128 rows and 4
  // columns. A base block is divided into 4 column blocks, each column block has 32 rows and 4
  // columns.
  // NOTE: There are not a lot of good illustrations about the swizzled scale factor matrix.
  // At a high level, the swizzled scale factor matrix could be composed as:
  //   unswizzled_scale_factor_matrix = torch.empty((M, N // 16), dtype=torch.uint8)
  //   cbg_cnt = N // 16 // 4  # Assuming N is divisible by 64
  //   rb_cnt = M // 128       # Assuming M is divisible by 128
  //   tmp = unswizzled_scale_factor_matrix.reshape(rb_cnt, 4, 32, cbg_cnt, 4)
  //   tmp = torch.permute(tmp, (0, 3, 2, 1, 4))
  //   swizzled_scale_factor_matrix = tmp.reshape((-1, 128, 4))
  constexpr uint32_t kTotalRowsPerBaseBlock = 128;
  constexpr uint32_t kRowsPerBaseBlockCol = 32;
  constexpr uint32_t kColsPerBaseBlockCol = 4;
  const size_t rb = row_idx / kTotalRowsPerBaseBlock;
  const size_t rem = row_idx % kTotalRowsPerBaseBlock;
  const size_t d4 = rem / kRowsPerBaseBlockCol;
  const size_t d3 = rem % kRowsPerBaseBlockCol;
  const size_t cbg = col_idx / kColsPerBaseBlockCol;
  const size_t d5 = col_idx % kColsPerBaseBlockCol;
  const size_t cbg_cnt = DIVUP(col_length, kColsPerBaseBlockCol);
  // row-major offset in the logical shape
  //   (rb_cnt, cbg_cnt, 32, 4, 4)
  // The magic number 16 below comes from the fact that kColsPerBaseBlockCol = 4 and
  // d4 ([0-128] / 32 = [0-4]).
  return ((rb * cbg_cnt + cbg) * kRowsPerBaseBlockCol + d3) * 16 + d4 * kColsPerBaseBlockCol + d5;
}
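
// Host-side reference (hypothetical, not part of the upstream file) for the swizzled layout
// described above: the same index algebra with the reshape/permute shape (rb_cnt, cbg_cnt, 32, 4, 4)
// made explicit, which can be handy for unit-testing the mapping against the closed form above.
inline size_t swizzled_offset_reference_sketch(size_t row_idx, size_t col_idx, size_t col_length) {
  const size_t cbg_cnt = (col_length + 3) / 4;  // DIVUP(col_length, 4)
  const size_t rb = row_idx / 128;              // which 128-row base block
  const size_t d4 = (row_idx % 128) / 32;       // which 32-row column block inside it
  const size_t d3 = row_idx % 32;               // row inside that column block
  const size_t cbg = col_idx / 4;               // group of 4 scale-factor columns
  const size_t d5 = col_idx % 4;                // column inside the group
  // Row-major offset in (rb_cnt, cbg_cnt, 32, 4, 4); identical to the "* 16" closed form above.
  return (((rb * cbg_cnt + cbg) * 32 + d3) * 4 + d4) * 4 + d5;
}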
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding(
    const float2 in01, const float2 in23, const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  uint16_t out_4x;
  asm volatile(
      "{\n"
      "cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5;\n\t"
      "}"
      : "=h"(out_4x)
      : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits));
  return *reinterpret_cast<__nv_fp4x4_e2m1 *>(&out_4x);
#else
  NVTE_DEVICE_ERROR(
      "FP4 cvt PTX instructions are architecture-specific. "
      "Try recompiling with sm_XXXa instead of sm_XXX.");
  uint16_t dummy = 0;
  return *reinterpret_cast<__nv_fp4x4_e2m1 *>(&dummy);
#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
                                                                      const float2 in23,
                                                                      const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  // NOTE: rbits unused for rn.
  uint32_t out_4x;  // Only need 16 bit. Using 32 bit container for packing.
  asm volatile(
      "{\n"
      ".reg.b8 f0;\n\t"
      ".reg.b8 f1;\n\t"
      "cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t"
      "cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t"
      "mov.b32 %0, {f0, f1, f0, f1};\n\t"
      "}"
      : "=r"(out_4x)
      : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x));
  return reinterpret_cast<__nv_fp4x4_e2m1 *>(&out_4x)[0];
#else
  NVTE_DEVICE_ERROR(
      "FP4 cvt PTX instructions are architecture-specific. "
      "Try recompiling with sm_XXXa instead of sm_XXX.");
  uint16_t dummy = 0;
  return *reinterpret_cast<__nv_fp4x4_e2m1 *>(&dummy);
#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
template <bool kApplyStochasticRounding>
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23,
                                                              const uint32_t rbits) {
  if constexpr (kApplyStochasticRounding) {
    return cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, rbits);
  } else {
    return cvt_fp32_to_fp4_4x_with_rn(in01, in23, rbits);
  }
}
template <bool kReturnIdentity, bool kReturnTranspose, bool kIsE8Scaling, bool kAligned,
          typename CType, typename IType, typename OType, typename ScaleType, bool kSwizzledScale,
          bool kApplyStochasticRounding, bool kIs2DBlockScaling>
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
    const IType *const input, const float *global_amax, OType *const output_c,
    OType *const output_t, ScaleType *const tile_scales_inv_c, ScaleType *const tile_scales_inv_t,
    const size_t row_length, const size_t num_rows, const size_t scale_stride_x,
    const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y,
    const size_t kScaleBlockDim, const float epsilon, const size_t *rng_state,
    const float *noop_ptr) {
  constexpr int kNVecContainer = kNVecOut / kNFP4PerContainer;
  using SMemVec = Vec<IType, kNVecSMem>;
  using OVec = Vec<OType, kNVecContainer>;
  union IVec {
    Vec<IType, kNVecIn> input_type;
    Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
  };

  if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
    return;
  }

  const size_t block_idx_x = blockIdx.x;
  const size_t block_idx_y = blockIdx.y;

  const size_t rng_sequence =
      threadIdx.x + block_idx_x * kThreadsPerBlock + block_idx_y * gridDim.x * kThreadsPerBlock;
  const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
  const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
  RNG rng(rng_seed, rng_sequence, rng_offset);
  curanddx::uniform_bits dist;
  uint4 random_uint4 = kApplyStochasticRounding ? dist.generate4(rng) : uint4{0, 0, 0, 0};
  // Index of the random number. It increments each time it is used and resets to 0 when it reaches 4.
  int rnd_idx = 0;

  extern __shared__ char smem_base[];
  SMemVec *smem = reinterpret_cast<SMemVec *>(&smem_base[0]);

  // 2D block scaling is not supported for E8 scaling (MXFP4) or for colwise-only mode.
  // Instead of static_assert, return early if these invalid modes are detected.
  if constexpr (kIs2DBlockScaling && kIsE8Scaling) {
    return;
  }
  if constexpr (kIs2DBlockScaling && !kReturnIdentity) {
    return;
  }

  // For a 128x128 block, 2D block scaling means there will be 8x8 amax values for nvfp4 (4x4 for
  // 2D mxfp4). Use constexpr to define the size; when not using 2D, use the minimal size 1x1.
  constexpr int kFP4BlockScalingSize = 16;
  constexpr int k2DBlockAmaxDim = kIs2DBlockScaling ? (kTileDim / kFP4BlockScalingSize) : 1;
  constexpr int kNumRowsPerWarp = kThreadsPerWarp / kNumThreadsStore;  // 4
  constexpr int k2DBlockAmaxReduceDim =
      kIs2DBlockScaling ? (kFP4BlockScalingSize / kNumRowsPerWarp) : 1;
  __shared__ CType amax_smem_red[k2DBlockAmaxDim][k2DBlockAmaxDim][k2DBlockAmaxReduceDim];
  __shared__ CType amax_smem[k2DBlockAmaxDim][k2DBlockAmaxDim];
  // Step 1: Load input to shared memory
  {
    constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad;  // stride in rows of shared memory
    constexpr int num_iterations = kTileDim / r_stride;
    const int c_s = (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem);  // Column in shared memory
    int r_s = threadIdx.x / kNumThreadsLoad;                                  // Row in shared memory
    const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem;              // Column in global memory
    size_t r_g = block_idx_y * kTileDim + r_s;                                // Row in global memory
    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;       // Stride in global memory
    const size_t num_ele =
        (c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g) : 0);  // For not aligned case
    const IType *input_g = &input[r_g * row_length + c_g];  // Input address in global memory

#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      IVec input_vec;
      // Step 1.1: Load from global memory (input) to registers
      if constexpr (kAligned) {
        input_vec.input_type.load_from(input_g);
      } else {
        if (r_g < num_rows) {
          input_vec.input_type.load_from_elts(input_g, 0, num_ele);
        } else {
          input_vec.input_type.clear();
        }
      }
      // Step 1.2: Write to shared memory
#pragma unroll
      for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
        int c = c_s + i;
        int r = r_s;
        smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
      }
      // Step 1.3: Update input address, row index of shared memory (and row index of global memory
      // for the not-aligned case)
      input_g += stride_g;
      r_s += r_stride;
      if constexpr (!kAligned) {
        r_g += r_stride;
      }
    }
  }
  __syncthreads();
  const int kNumThreadsReduce = kScaleBlockDim / kNVecOut;

  const float global_encode_scale =
      kIsE8Scaling ? 1.0f : ComputeGlobalEncodeScaleFP4(global_amax[0]);
  const float global_decode_scale = 1.0 / global_encode_scale;
  // Step 2: Cast and store to output_c
  if constexpr (kReturnIdentity) {
    constexpr int r_stride = kThreadsPerBlock / kNumThreadsStore;  // stride in rows of shared memory
    constexpr int num_iterations = kTileDim / r_stride;
    const int c_s = (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem);  // Column in shared memory
    int r_s = threadIdx.x / kNumThreadsStore;                                   // Row in shared memory
    const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem;                // Column in global memory
    size_t r_g = block_idx_y * kTileDim + r_s;                                  // Row in global memory
    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;         // Stride in global memory
    const size_t num_ele =
        (c_g < row_length ? min(static_cast<size_t>(kNVecOut / kNFP4PerContainer),
                                (row_length - c_g) / kNFP4PerContainer)
                          : 0);  // For not aligned case
    OType *output_g = &output_c[(r_g * row_length + c_g) / kNFP4PerContainer];  // Output address in global memory
    // Each group of kNumThreadsStore threads in a warp processes one row; we need to find the lane
    // id of the first thread to do the reduction.
    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce;
    // This mask represents which threads should do the reduction together.
    const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane;
    const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0;

#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      SMemVec smem_vec[kNVecOut / kNVecSMem];
      // Step 2.1: Load from shared memory to registers
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
        int c = c_s + i;
        int r = r_s;
        smem_vec[i] = smem[r * kSMemCol + c];
      }
      // Step 2.2: Compute local amax
      CType amax = 0;
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
        for (int j = 0; j < kNVecSMem; ++j) {
          __builtin_assume(amax >= 0);
          amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
        }
      }
      // Step 2.3: Reduce amax
      if constexpr (kIsE8Scaling) {
#pragma unroll
        for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) {
          const float other_amax = __shfl_down_sync(mask, amax, delta);
          __builtin_assume(amax >= 0);
          __builtin_assume(other_amax >= 0);
          amax = fmaxf(amax, other_amax);
        }
        amax = __shfl_sync(mask, amax, src_lane);
      }
      // Shuffle-sync for 2D block scaling (not applicable for E8 scaling)
      if constexpr (kIs2DBlockScaling) {
        // First amax shuffle sync in warp, then reduce in smem.
        // T0 T8 T16 T24 should do amax reduction together.
        constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore;  // 32
        int warp_idx = threadIdx.x / kThreadsPerWarp;                         // 0 ~ 7
        int tid_in_warp_x = threadIdx.x % kNumThreadsStore;
        int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp;
        CType amax_warp_reduced = groupMax<kNumRowsPerWarp, kNumThreadsStore>(
            amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]);
        // Now T0 ~ T8 in each warp have the reduced amax values.
        int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y;
        if (tid_in_warp_y == 0) {
          amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x]
                       [warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced;
        }
        __syncthreads();
        if (data_row_idx % kFP4BlockScalingSize == 0) {
          CType amax_2d = 0.0;
          for (int i = 0; i < k2DBlockAmaxReduceDim; i++) {
            amax_2d = fmaxf(amax_2d,
                            amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]);
          }
          amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d;
        }
        __syncthreads();
        // Every thread now knows the 2D amax.
        amax = amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x];
      }
      // Step 2.4: Compute scale
      ScaleType scale_inv = ComputeDecodeScaleFP4<ScaleType>(amax, global_encode_scale);
      float encode_scale = ComputeEncodeScaleFP4<ScaleType>(scale_inv, global_decode_scale);
      // Step 2.5: Write scale_inv
      bool write_scale_inv = is_src_lane;
      if constexpr (!kAligned) {
        write_scale_inv &= (r_g < num_rows);
        write_scale_inv &= (c_g < row_length);
      }
      if (write_scale_inv) {
        size_t row_idx = block_idx_y * kTileDim + r_s;
        size_t col_idx = block_idx_x * (kNumThreadsStore / kNumThreadsReduce) +
                         (threadIdx.x % kNumThreadsStore) / kNumThreadsReduce;
        if constexpr (kSwizzledScale) {
          size_t offset = scale_factor_swizzled_offset<ScaleType>(
              row_idx, col_idx, DIVUP(row_length, kScaleBlockDim));
          tile_scales_inv_c[offset] = scale_inv;
        } else {
          tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
        }
      }
      // Step 2.6: Quantize
      OVec output_vec;
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; i += 2) {
        // Pack two elements at a time into float2
        float2 f2_a;
        float2 f2_b;
        f2_a.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[i].data.elt[0], encode_scale);
        f2_a.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[i].data.elt[1], encode_scale);
        f2_b.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[i + 1].data.elt[0], encode_scale);
        f2_b.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[i + 1].data.elt[1], encode_scale);
        const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0;
        // Convert to __nv_fp4x4_e2m1
        __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x<kApplyStochasticRounding>(f2_a, f2_b, rbits);
        output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t *>(&out_4x)[0];
        output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t *>(&out_4x)[1];
      }
      // Step 2.7: Store output_c
      if constexpr (kAligned) {
        output_vec.store_to(output_g);
      } else {
        if (r_g < num_rows) {
          output_vec.store_to_elts(output_g, 0, num_ele);
        }
      }
      // Step 2.8: Update output address, row index of shared memory (and row index of global memory
      // for the not-aligned case)
      output_g += stride_g / kNFP4PerContainer;
      r_s += r_stride;
      if constexpr (!kAligned) {
        r_g += r_stride;
      }
    }
  }
  // Step 3: Transpose, cast and store to output_t
  if constexpr (kReturnTranspose) {
    constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore;  // Stride in columns of shared memory
    constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
    const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut;  // Row in shared memory
    int c_s = threadIdx.x / kNumThreadsStore;                     // Column in shared memory
    size_t r_g = block_idx_x * kTileDim + c_s * kNVecSMem;        // Row in global memory
    const size_t c_g = block_idx_y * kTileDim + r_s;              // Column in global memory
    const size_t stride_g = static_cast<size_t>(c_stride) * kNVecSMem * num_rows;  // Stride in global memory
    const size_t num_ele =
        (c_g < num_rows ? min(static_cast<size_t>(kNVecOut / kNFP4PerContainer),
                              (num_rows - c_g) / kNFP4PerContainer)
                        : 0);  // For not aligned case
    OType *output_g = &output_t[(r_g * num_rows + c_g) / kNFP4PerContainer];  // Output address in global memory
    // Each group of kNumThreadsStore threads in a warp processes one row; we need to find the lane
    // id of the first thread to do the reduction.
    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce;
    // This mask represents which threads should do the reduction together.
    const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane;
    const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0;

#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      SMemVec smem_vec[kNVecOut];
      // Step 3.1: Load from shared memory to registers
#pragma unroll
      for (int i = 0; i < kNVecOut; ++i) {
        int r = r_s + i;
        int c = c_s;
        smem_vec[i] = smem[r * kSMemCol + c];
      }
#pragma unroll
      for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
        // Step 3.2: Compute local amax
        CType amax = 0;
        if constexpr (kIs2DBlockScaling) {
          // TODO(zhongbo): 2D block scaling, directly read from amax_smem
          int warp_idx = threadIdx.x / kThreadsPerWarp;                                    // 0 ~ 7
          constexpr int kNumColsPerWarp = kThreadsPerWarp / kNumThreadsStore * kNVecSMem;  // 8 elements
          constexpr int kNumWarpsPerBlock = kThreadsPerBlock / kThreadsPerWarp;            // 8 warps per block
          constexpr int kNumColsPerIter = kNumColsPerWarp * kNumWarpsPerBlock;
          int tid_in_warp_x = (threadIdx.x / kNumThreadsStore) % kNumColsPerWarp;
          int tid_in_warp_y = (threadIdx.x % kThreadsPerWarp) % kNumThreadsStore;
          int data_col_idx = iter * kNumColsPerIter + warp_idx * kNumColsPerWarp + tid_in_warp_x;
          amax = amax_smem[tid_in_warp_y][data_col_idx / kFP4BlockScalingSize];
        } else {
#pragma unroll
          for (int i = 0; i < kNVecOut; ++i) {
            amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx]));
          }
        }
        // Step 3.3: Reduce amax
        if constexpr (kIsE8Scaling) {
#pragma unroll
          for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) {
            const float other_amax = __shfl_down_sync(mask, amax, delta);
            __builtin_assume(amax >= 0);
            __builtin_assume(other_amax >= 0);
            amax = fmaxf(amax, other_amax);
          }
          amax = __shfl_sync(mask, amax, src_lane);
        }
        // Step 3.4: Compute scale
        ScaleType scale_inv = ComputeDecodeScaleFP4<ScaleType>(amax, global_encode_scale);
        float encode_scale = ComputeEncodeScaleFP4<ScaleType>(scale_inv, global_decode_scale);
        // Step 3.5: Write scale_inv_t
        bool write_scale_inv = is_src_lane;
        if constexpr (!kAligned) {
          write_scale_inv &= (r_g + smem_idx < row_length);
          write_scale_inv &= (c_g < num_rows);
        }
        if (write_scale_inv) {
          size_t row_idx = block_idx_x * kTileDim + c_s * kNVecSMem + smem_idx;
          size_t col_idx = (block_idx_y * (kNumThreadsStore / kNumThreadsReduce) +
                            (threadIdx.x % kNumThreadsStore) / kNumThreadsReduce);
          if constexpr (kSwizzledScale) {
            size_t offset = scale_factor_swizzled_offset<ScaleType>(
                row_idx, col_idx, DIVUP(num_rows, kScaleBlockDim));
            tile_scales_inv_t[offset] = scale_inv;
          } else {
            tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
          }
        }
        // Step 3.6: Quantize
        OVec output_vec;
#pragma unroll
        for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 2) {
          // Pack two elements at a time into float2
          float2 f2_a;
          float2 f2_b;
          f2_a.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * i].data.elt[smem_idx], encode_scale);
          f2_a.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * i + 1].data.elt[smem_idx], encode_scale);
          f2_b.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * (i + 1)].data.elt[smem_idx], encode_scale);
          f2_b.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * (i + 1) + 1].data.elt[smem_idx], encode_scale);
          const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0;
          // Convert to __nv_fp4x4_e2m1
          __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x<kApplyStochasticRounding>(f2_a, f2_b, rbits);
          output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t *>(&out_4x)[0];
          output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t *>(&out_4x)[1];
        }
        // Step 3.7: Store output_t
        if constexpr (kAligned) {
          output_vec.store_to(output_g + smem_idx * num_rows / kNFP4PerContainer);
        } else {
          if (r_g + smem_idx < row_length) {
            output_vec.store_to_elts(output_g + smem_idx * num_rows / kNFP4PerContainer, 0, num_ele);
          }
        }
      }
      // Step 3.8: Update output address, column index of shared memory (and row index of global
      // memory for the not-aligned case)
      output_g += stride_g / kNFP4PerContainer;
      c_s += c_stride;
      if constexpr (!kAligned) {
        r_g += c_stride * kNVecSMem;
      }
    }
  }
}

}  // namespace
}  // namespace quantize_transpose_nvfp4

#endif  // CUDA_VERSION >= 12080
namespace detail {

void quantize_transpose_vector_blockwise_fp4(
    const SimpleTensor &input, const SimpleTensor &global_amax, SimpleTensor &scale_inv,
    SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon,
    const bool return_identity, const bool return_transpose, const bool pow2_scale,
    const bool swizzled_scale, const bool use_stochastic_rounding,
    const NVTETensor rng_state_tensor, const bool use_2d_quantization,
    const SimpleTensor &noop_tensor, cudaStream_t stream) {
  NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4);
#if CUDA_VERSION >= 12080
  // pow2 scale is for MXFP4 since it's using E8M0 scaling; raise an error if pow2_scale is true.
  NVTE_CHECK(!pow2_scale, "No support for pow2_scale for MXFP4 for now");

  if (!return_identity && !return_transpose) {
    return;
  }
  if (use_2d_quantization && !return_identity) {
    return;
  }

  const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
  size_t num_elements = row_length;
  size_t num_rows = 1;
  for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
    num_rows *= input.shape.at(i);
    num_elements *= input.shape.at(i);
  }

  // Early return if the input tensor is empty
  if (num_elements == 0) {
    return;
  }

  size_t scale_stride_x = 0;
  size_t scale_stride_y = 0;
  if (return_identity) {
    scale_stride_x = 1;
    scale_stride_y = scale_inv.shape[1];
  }
  size_t scale_t_stride_x = 0;
  size_t scale_t_stride_y = 0;
  if (return_transpose) {
    scale_t_stride_x = 1;
    scale_t_stride_y = scale_inv_t.shape[1];
  }

  using namespace transformer_engine::quantize_transpose_nvfp4;
  const size_t num_blocks_x = DIVUP(row_length, static_cast<size_t>(kTileDim));
  const size_t num_blocks_y = DIVUP(num_rows, static_cast<size_t>(kTileDim));

  // noop tensor for cuda graph
  const float *noop_ptr = reinterpret_cast<const float *>(noop_tensor.dptr);

  const size_t *rng_state = nullptr;
  if (rng_state_tensor != nullptr) {
    Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
    NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
               "RNG state should contain 2 64-bit values.");
    NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
               "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
    rng_state = reinterpret_cast<const size_t *>(rng_state_te_tensor.data.dptr);
  }

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.dtype, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(
          output.dtype, 2, OutputType,

          dim3 grid(num_blocks_x, num_blocks_y, 1);
          using ScaleType = fp8e4m3;
          constexpr int kScaleBlockDim = 16;
          constexpr bool kPow2Scale = false;
          const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;

          TRANSFORMER_ENGINE_SWITCH_CONDITION(
              return_identity, kReturnIdentity,
              TRANSFORMER_ENGINE_SWITCH_CONDITION(
                  return_transpose, kReturnTranspose,
                  TRANSFORMER_ENGINE_SWITCH_CONDITION(
                      full_tile, kAligned,
                      TRANSFORMER_ENGINE_SWITCH_CONDITION(
                          swizzled_scale, kSwizzledScale,
                          TRANSFORMER_ENGINE_SWITCH_CONDITION(
                              use_stochastic_rounding, kApplyStochasticRounding,
                              TRANSFORMER_ENGINE_SWITCH_CONDITION(
                                  use_2d_quantization, kIs2DBlockScaling,

                                  size_t smem_bytes = kSMemSize * sizeof(InputType);
                                  auto kernel = block_scaled_1d_cast_transpose_kernel<
                                      kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned,
                                      float, InputType, OutputType, ScaleType, kSwizzledScale,
                                      kApplyStochasticRounding, kIs2DBlockScaling>;
                                  if (smem_bytes >= 48 * 1024) {
                                    cudaError_t err = cudaFuncSetAttribute(
                                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                        smem_bytes);
                                    NVTE_CHECK(err == cudaSuccess,
                                               "Failed to set dynamic shared memory size.");
                                  }
                                  kernel<<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
                                      reinterpret_cast<const InputType *>(input.dptr),
                                      reinterpret_cast<const float *>(global_amax.dptr),
                                      reinterpret_cast<OutputType *>(output.dptr),
                                      reinterpret_cast<OutputType *>(output_t.dptr),
                                      reinterpret_cast<ScaleType *>(scale_inv.dptr),
                                      reinterpret_cast<ScaleType *>(scale_inv_t.dptr),
                                      row_length, num_rows, scale_stride_x, scale_stride_y,
                                      scale_t_stride_x, scale_t_stride_y, kScaleBlockDim, epsilon,
                                      rng_state, noop_ptr);)  // kIs2DBlockScaling
                              )                               // kApplyStochasticRounding
                          )                                   // kSwizzledScale
                      )                                       // kAligned
                  )                                           // kReturnTranspose
              )                                               // kReturnIdentity
          )                                                   // OutputType
      )                                                       // InputType

  NVTE_CHECK_CUDA(cudaGetLastError());
#else
  NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif  // CUDA_VERSION >= 12080
}

}  // namespace detail
}  // namespace transformer_engine
transformer_engine/common/util/cast_gated_kernels.cuh

@@ -58,7 +58,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
                        const __grid_constant__ CUtensorMap tensor_map_output_act,
                        const __grid_constant__ CUtensorMap tensor_map_output_gate,
                        float *const amax_ptr, float *const scale_inv_ptr,
-                       const float *const scale_ptr, const size_t rows, const size_t cols) {
+                       const float *const scale_ptr, const size_t rows, const size_t cols,
+                       const ParamOP p) {
 #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
   const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
@@ -164,7 +165,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
     IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems;
     OType *out_act_sh_curr = out_act_sh + buff * buff_elems;
     OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems;
-
 #pragma unroll
     for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
       const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
@@ -174,6 +174,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
           float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
           float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
+          bool dgate_elt = true;  // gating is ideally an identity function
+          if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
+            // In case of GPT OSS, clamp the activation and gate values
+            dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;  // Derivative of clamp
+            gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1;
+          }

           if constexpr (IS_DGATED) {
             float grad_elt = static_cast<float>(in_grad_sh_curr[shmem_idx]);
@@ -181,18 +187,27 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
             const float x = act_elt;
             float act_x;
             float dact_x;
-            if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
-              const float s = sigmoidf(x);
-              act_x = x * s;
-              dact_x = x * s * (1 - s) + s;
+            if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
+              const float x = min(act_elt, p.limit);
+              const float s = sigmoidf(p.alpha * x);
+              act_x = x * s;
+              if (act_elt <= p.limit) {
+                dact_x = s + s * (1 - s) * p.alpha * x;
+              } else {
+                dact_x = 0.0f;
+              }
             } else {
-              act_x = ActOP(x, {});
-              dact_x = DActOP(x, {});
+              if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
+                const float s = sigmoidf(x);
+                act_x = x * s;
+                dact_x = x * s * (1 - s) + s;
+              } else {
+                act_x = ActOP(x, p);
+                dact_x = DActOP(x, p);
+              }
             }

             float after_dact = dact_x * grad_elt * gate_elt;
-            float after_dgate = act_x * grad_elt;
+            float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f;

             out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dact);
             out_gate_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dgate);
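
For reference, the sketch below is a plain scalar transcription of the ClampedSwiGLU backward arithmetic introduced in this hunk (an illustration, not code from the diff); `limit` and `alpha` are assumed to be the ClampedSwiGLUParam fields the kernel reads as p.limit and p.alpha.

#include <algorithm>
#include <cmath>

struct ClampedSwiGLUGrads {
  float dact;   // contribution written to out_act (after_dact)
  float dgate;  // contribution written to out_gate (after_dgate)
};

inline ClampedSwiGLUGrads clamped_swiglu_backward_ref(float act, float gate, float grad,
                                                      float limit, float alpha) {
  // Gate path: derivative of the clamp is 1 inside [-limit, limit] and 0 outside;
  // the clamped gate is shifted by +1 before it multiplies the activation branch.
  const bool gate_in_range = (gate <= limit) && (gate >= -limit);
  const float gate_shifted = std::min(std::max(-limit, gate), limit) + 1.0f;

  // Activation path: x * sigmoid(alpha * x) with x clamped at `limit` from above.
  const float x = std::min(act, limit);
  const float s = 1.0f / (1.0f + std::exp(-alpha * x));
  const float act_x = x * s;
  const float dact_x = (act <= limit) ? s + s * (1.0f - s) * alpha * x : 0.0f;

  return {dact_x * grad * gate_shifted,          // matches after_dact in the kernel
          gate_in_range ? act_x * grad : 0.0f};  // matches after_dgate in the kernel
}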
@@ -200,7 +215,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
             amax = fmaxf(amax, fabsf(after_dact));
             amax = fmaxf(amax, fabsf(after_dgate));
           } else {
-            const float after_act = ActOP(act_elt, {}) * gate_elt;
+            const float after_act = ActOP(act_elt, p) * gate_elt;
             out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_act);
             amax = fmaxf(amax, fabsf(after_act));
           }
@@ -305,7 +320,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
                           const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise,
                           e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise,
                           const size_t rows, const size_t cols, const size_t scale_stride_rowwise,
-                          const size_t scale_stride_colwise) {
+                          const size_t scale_stride_colwise, const ParamOP p) {
 #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
   using IType2 = typename ptx::FPx2<IType>;
   using OType2 = typename ptx::FPx2<OType>;
@@ -481,25 +496,37 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
       float gate_elt = static_cast<float>(in_gate_sh[shmem_offset_colwise]);
       float after_act_elt;
       float after_gate_elt;
+      bool dgate_elt = true;  // gating is ideally an identity function
+      if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
+        // In case of GPT OSS, clamp the activation and gate values
+        dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;  // Derivative of clamp
+        gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
+      }

       if constexpr (IS_DGATED) {
         float grad_elt = static_cast<float>(in_grad_sh[shmem_offset_colwise]);
         const float x = act_elt;
         float act_x;
         float dact_x;
-        if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
-          const float s = sigmoidf(x);
-          act_x = x * s;
-          dact_x = x * s * (1 - s) + s;
+        if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
+          const float x = min(act_elt, p.limit);
+          const float s = sigmoidf(p.alpha * x);
+          act_x = x * s;
+          dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f;
         } else {
-          act_x = ActOP(x, {});
-          dact_x = DActOP(x, {});
+          if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
+            const float s = sigmoidf(x);
+            act_x = x * s;
+            dact_x = x * s * (1 - s) + s;
+          } else {
+            act_x = ActOP(x, p);
+            dact_x = DActOP(x, p);
+          }
         }

         after_act_elt = dact_x * grad_elt * gate_elt;
-        after_gate_elt = act_x * grad_elt;
+        after_gate_elt = dgate_elt ? act_x * grad_elt : 0.0f;
       } else {
-        after_act_elt = ActOP(act_elt, {}) * gate_elt;
+        after_act_elt = ActOP(act_elt, p) * gate_elt;
       }

       // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
       if constexpr (!std::is_same_v<IType, float>) {
@@ -603,6 +630,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
       if constexpr (IS_DGATED) {
         const e8m0_t biased_exponent_gate =
             ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
         // const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2;
         const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
         if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
@@ -724,27 +752,39 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
         float gate_elt = static_cast<float>(in_gate.data.elt[e]);
         float after_act_elt;
         float after_gate_elt;
+        bool dgate_elt = true;
+        if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
+          // In case of GPT OSS, clamp the activation and gate values
+          dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;  // Derivative of clamp
+          gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
+        }

         if constexpr (IS_DGATED) {
           float grad_elt = static_cast<float>(in_grad.data.elt[e]);
           const float x = act_elt;
           float act_x;
           float dact_x;
-          if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
-            const float s = sigmoidf(x);
-            act_x = x * s;
-            dact_x = x * s * (1 - s) + s;
+          if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
+            const float x = min(act_elt, p.limit);
+            const float s = sigmoidf(p.alpha * x);
+            act_x = x * s;
+            dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f;
          } else {
-            act_x = ActOP(x, {});
-            dact_x = DActOP(x, {});
+            if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
+              const float s = sigmoidf(x);
+              act_x = x * s;
+              dact_x = x * s * (1 - s) + s;
+            } else {
+              act_x = ActOP(x, p);
+              dact_x = DActOP(x, p);
+            }
           }

           after_act_elt = dact_x * grad_elt * gate_elt;
-          after_gate_elt = act_x * grad_elt;
+          after_gate_elt = dgate_elt ? act_x * grad_elt : 0.0f;
           after_act_rowwise[j] = after_act_elt;
           after_gate_rowwise[j] = after_gate_elt;
         } else {
-          after_act_elt = ActOP(act_elt, {}) * gate_elt;
+          after_act_elt = ActOP(act_elt, p) * gate_elt;
           after_act_rowwise[j] = after_act_elt;
         }
@@ -833,6 +873,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
             ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate);
           }
         }
         const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
         const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
         const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
@@ -889,7 +930,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
 template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
           float (*DActOP)(float, const ParamOP &)>
-void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
+void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
                     cudaStream_t stream) {
 #ifdef __HIP_PLATFORM_AMD__
   assert(false);
@@ -956,6 +997,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
   const size_t in_gate_mem = buff_size_aligned_in;
   const size_t out_act_mem = buff_size_aligned_out;
   const size_t out_gate_mem = buff_size_aligned_out;
   const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) +
                             (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT;
@@ -966,8 +1008,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
...
@@ -966,8 +1008,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
cast_fp8_gated_kernel
<
IS_DGATED
,
ParamOP
,
ActOP
,
DActOP
,
IType
,
OType
>
cast_fp8_gated_kernel
<
IS_DGATED
,
ParamOP
,
ActOP
,
DActOP
,
IType
,
OType
>
<<<
grid_dim
,
block_dim
,
shmem_size
,
stream
>>>
(
<<<
grid_dim
,
block_dim
,
shmem_size
,
stream
>>>
(
tensor_map_grad
,
tensor_map_input_act
,
tensor_map_input_gate
,
tensor_map_output_act
,
tensor_map_grad
,
tensor_map_input_act
,
tensor_map_input_gate
,
tensor_map_output_act
,
tensor_map_output_gate
,
amax_ptr
,
scale_inv_ptr
,
scale_ptr
,
rows
,
tensor_map_output_gate
,
amax_ptr
,
scale_inv_ptr
,
scale_ptr
,
rows
,
cols
,
p
);
cols
);
NVTE_CHECK_CUDA
(
cudaGetLastError
()););
// NOLINT(*)
NVTE_CHECK_CUDA
(
cudaGetLastError
()););
// NOLINT(*)
);
// NOLINT(*)
);
// NOLINT(*)
#endif
#endif
...
@@ -975,7 +1016,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
          float (*DActOP)(float, const ParamOP &)>
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
                      ParamOP p, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
  assert(false);
...
@@ -1109,7 +1150,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
              tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
              tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
              scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
              scale_stride_colwise, p);
      NVTE_CHECK_CUDA(cudaGetLastError());
      break;
    case ScalingType::COLWISE:
...
@@ -1126,7 +1167,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
              tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
              tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
              scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
              scale_stride_colwise, p);
      NVTE_CHECK_CUDA(cudaGetLastError());
      break;
    case ScalingType::BIDIMENSIONAL:
...
@@ -1135,7 +1176,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
              OType, true, true,
              THREADS_PER_CHUNK_NON_COLWISE>,
          cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
      mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
                                            true, true, THREADS_PER_CHUNK_NON_COLWISE>
          <<<grid, block_size, shmem_size, stream>>>(
...
@@ -1143,7 +1183,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
              tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
              tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
              scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
              scale_stride_colwise, p);
      NVTE_CHECK_CUDA(cudaGetLastError());
      break;
  });  // NOLINT(*)
...
@@ -1152,12 +1192,9 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
}

template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) {
  CheckInputTensor(input, "gated_act_input");
  CheckOutputTensor(*output, "gated_act_output");
  NVTE_CHECK(output->flat_first_dim() == input.flat_first_dim(),
             "Wrong output shape. Expected (after flattening) [", input.flat_first_dim(),
             ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
  NVTE_CHECK(input.flat_last_dim() % 2 == 0,
             "Wrong input shape. Expected (after flattening) last dimension to be even, ",
             "got [", input.flat_first_dim(), ", ", input.flat_last_dim(), "].");
...
@@ -1179,7 +1216,7 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
            reinterpret_cast<const fp32 *>(output->scale.dptr),
            reinterpret_cast<fp32 *>(output->amax.dptr),
            reinterpret_cast<fp32 *>(output->scale_inv.dptr), input.flat_first_dim(),
            output->flat_last_dim(), p, stream);
      } else {
        NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
  });  // NOLINT(*)
...
@@ -1188,7 +1225,8 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
          float (*DActOP)(float, const ParamOP &)>
void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p,
                 cudaStream_t stream) {
  CheckInputTensor(grad, "dgated_act_grad");
  CheckInputTensor(input, "dgated_act_input");
  CheckOutputTensor(*output, "dgated_act_output");
...
@@ -1217,7 +1255,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt
            reinterpret_cast<const fp32 *>(output->scale.dptr),
            reinterpret_cast<fp32 *>(output->amax.dptr),
            reinterpret_cast<fp32 *>(output->scale_inv.dptr), grad.flat_first_dim(),
            grad.flat_last_dim(), p, stream);
      } else {
        NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
  });  // NOLINT(*)
...
@@ -1226,7 +1264,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
          float (*DActOP)(float, const ParamOP &)>
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
                    ParamOP p, cudaStream_t stream) {
  constexpr bool allow_empty = false;
  CheckInputTensor(gated_input, "gated_input");
...
@@ -1266,17 +1304,17 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
  if (is_delayed_tensor_scaling(output->scaling_mode)) {
    if (use_tma_kernels) {
      cast_fp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
    } else {
      if constexpr (IS_DGATED) {
        cast_dgated<ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
      } else {
        cast_gated<ParamOP, ActOP>(gated_input, output, p, stream);
      }
    }
  } else if (is_mxfp8_scaling(output->scaling_mode)) {
    if (use_tma_kernels) {
      cast_mxfp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
    } else {
      NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ",
                 "by 32, got input of shape ", gated_input.data.shape);
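The functional change in this dispatch is only that the activation parameter object p is now forwarded through every path instead of being default-constructed further down. A stripped-down sketch of the same pattern with a purely illustrative SquareParam functor (not part of Transformer Engine):

#include <cstdio>

struct SquareParam {  // illustrative stand-in for a real ParamOP (e.g. a clamp limit)
  float scale;
};

float square_act(float x, const SquareParam &p) { return p.scale * x * x; }

template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void apply_gated(const float *input, float *output, int n, ParamOP p) {
  // The parameter object is passed down unchanged, exactly like `p` in quantize_gated above.
  for (int i = 0; i < n; ++i) output[i] = ActOP(input[i], p);
}

int main() {
  const float in[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  float out[4];
  apply_gated<SquareParam, square_act>(in, out, 4, SquareParam{0.5f});
  for (float v : out) std::printf("%f\n", v);
  return 0;
}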
...
@@ -1292,7 +1330,7 @@ namespace detail {
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
          float (*DActOP)(float, const ParamOP &)>
void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output,
                           ParamOP p, cudaStream_t stream) {
  using namespace gated_kernels;
  Tensor grad_empty_tensor;
  const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
...
@@ -1301,13 +1339,14 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input,
  if (is_supported_by_CC_100()) {
    quantize_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor,
                                                      output_tensor, p, stream);
  } else {
    if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) {
      if constexpr (IS_DGATED) {
        cast_dgated<ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor, output_tensor, p,
                                            stream);
      } else {
        cast_gated<ParamOP, ActOP>(gated_input_tensor, output_tensor, p, stream);
      }
    } else {
      // MX scaling
...
transformer_engine/common/util/cast_kernels.cuh
View file @ 063ef88d
...
@@ -25,6 +25,7 @@
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "math.h"
#include "nvfp4_transpose.cuh"
#include "ptx.cuh"
#include "transformer_engine/transformer_engine.h"
...
@@ -110,6 +111,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
  const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
  const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
  const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
  // helps resolving bank conflicts in shmem
  const int thread_lane = threadIdx.x % THREADS_PER_WARP;
  const int bank_group = thread_lane / THREADS_PER_BANK;
...
@@ -137,8 +140,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
  IType *in_sh = reinterpret_cast<IType *>(dshmem);
  IType *act_in_sh = reinterpret_cast<IType *>(dshmem + elt_input_mem);
  OType *out_rowwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem);
  OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
  IType *cached_act_sh = in_sh;  // in_sh is used as a cache buffer
  constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
...
@@ -286,7 +290,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
        const float scaled_out = in * block_scale_inverse;
        const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
        out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
      }
    }
...
@@ -410,10 +414,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
        // 2. Compute E8M0 scaling factor
        const e8m0_t biased_exponent =
            ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
        const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
        const int stage_scales_offset_X = scales_offset_X_rowwise;
        const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
        if (rowwise_scale_is_within_bounds) {
          scales_rowwise[scale_idx] = biased_exponent;
        }
        const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
        const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
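The E8M0 factor stored here is just a biased power-of-two exponent, so the stored byte and the scale it encodes can be reproduced with ordinary floats. A rough host-side illustration, assuming float_to_e8m0 rounds the requested scale up to the next power of two and exp2f_rcp returns 2^-(e-127) (the real device implementations live in ptx.cuh), and using 448 as the E4M3 max-normal behind max_norm_rcp:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative only: biased exponent e such that 2^(e-127) >= amax / max_norm.
uint8_t to_e8m0(float scale) {
  const int e = static_cast<int>(std::ceil(std::log2(scale))) + 127;
  return static_cast<uint8_t>(e < 0 ? 0 : (e > 255 ? 255 : e));
}

int main() {
  const float amax = 3.1f;
  const float max_norm_rcp = 1.0f / 448.0f;  // FP8 E4M3 max normal is 448
  const uint8_t biased = to_e8m0(amax * max_norm_rcp);
  const float scale_inv = std::exp2(-(static_cast<int>(biased) - 127));
  std::printf("biased_exponent=%u scale_inv=%f scaled_amax=%f\n", biased, scale_inv,
              amax * scale_inv);  // scaled amax stays within the E4M3 range
  return 0;
}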
...
@@ -441,7 +447,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
          const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
          const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
          const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
          out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
        }
      }
...
@@ -456,19 +462,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
    // Initiate TMA transfer to copy shared memory to global memory
    if (is_master_thread) {
      const int global_offset_Y = block_offset_Y + stage_offset_Y;
      const int global_offset_X = block_offset_X;
      const int buff_offset = buff * BUFF_DIM;
      if constexpr (ROWWISE_SCALING) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
            global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset]));
      }
      if constexpr (COLWISE_SCALING) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
            global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset]));
      }
      // Create a "bulk async-group" out of the previous bulk copy operation.
...
@@ -489,18 +495,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
    // Added extra 1-element padding per thread_X to reduce bank conflicts
    float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);
    constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
    const int shmem_thread_offset =
        tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1);
#pragma unroll
    for (int w = 0; w < WAVES; ++w) {
      const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
      const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
#pragma unroll
      for (int e = 0; e < PACK_SIZE; ++e) {
        const int j = w * PACK_SIZE + e;
        const int shmem_elt_idx = swizzled_group_offset + e;
        partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j];
      }
    }
...
@@ -508,15 +514,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll
      for (int i = 0; i < THREADS_Y; ++i) {
        // Add extra element offset per MXFP8 scaling block [1x32]
        const int scaling_block = threadIdx.x / SCALE_DIM_X;
        thread_partial_dbias +=
            partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block];
      }
    }
    const int dbias_stride = cols;
    const int dbias_offset_Y = blockIdx.y;
    const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
    const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
    const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols);
    if (!col_out_of_bounds_dbias) {
      dbias_workspace[dbias_idx] = thread_partial_dbias;
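The (w + bank_group) swizzle used throughout these loops only changes which PACK_SIZE-wide group a thread touches first, so that threads sharing a bank group start on different shared-memory banks; the set of indices each thread visits is unchanged. A small host-side sketch of the resulting index pattern, with illustrative values (SCALE_DIM_X = 32, PACK_SIZE = 8, THREADS_PER_BANK = 4 are assumptions of this sketch, not taken from the kernel):

#include <cstdio>

int main() {
  // Illustrative values; the kernels derive these from their template parameters.
  const int SCALE_DIM_X = 32, PACK_SIZE = 8, WAVES = SCALE_DIM_X / PACK_SIZE;
  const int THREADS_PER_BANK = 4;  // assumption for this sketch
  for (int thread_lane = 0; thread_lane < 8; ++thread_lane) {
    const int bank_group = thread_lane / THREADS_PER_BANK;
    std::printf("lane %d:", thread_lane);
    for (int w = 0; w < WAVES; ++w) {
      // Same swizzle as in the kernel: rotate the starting pack by bank group.
      const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
      std::printf(" %2d", swizzled_group_idx);
    }
    std::printf("\n");
  }
  return 0;
}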
...
@@ -539,6 +545,528 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#endif  // __HIP_PLATFORM_AMD__
}  // namespace mxfp8_kernel

namespace nvfp4_kernel {

using namespace ptx;

constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 16;

constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = 32;

constexpr size_t PACK_SIZE = 8;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;

// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4;  // 256

// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X;  // 8 = 128 / 16

// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax,
                                                                   const float S_enc) {
  constexpr float rcp_6f = 1.0f / 6.0f;
  // const float S_dec_b = block_amax * rcp_6f;
  // const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
  // return S_dec_b_fp8;
  return static_cast<fp8e4m3>(block_amax * rcp_6f * S_enc);
}

#define DIRECT_SCALING_FACTORS_STORE 1

template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
          typename IType, typename OType, bool COLWISE_SCALING, size_t CHUNK_DIM_Y,
          size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
    cast_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
                      const __grid_constant__ CUtensorMap tensor_map_output_rowwise,
                      const __grid_constant__ CUtensorMap tensor_map_output_colwise,
                      fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0,
                      const float *noop, float *const amax_ptr,
                      const float *const nvfp4_second_stage_scale_ptr, const size_t rows,
                      const size_t cols, const size_t scale_stride_rowwise,
                      const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  constexpr bool ROWWISE_SCALING = true;
  constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
      (!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
  using IType2 = typename ptx::FPx2<IType>;

  if constexpr (!COMPUTE_ACTIVATIONS) {
    if (noop != nullptr && noop[0] == 1.0f) {
      return;
    }
  }

  constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X;
  constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW;
  constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE;

  static_assert(BUFF_DIM_Y >= SCALE_DIM_Y &&
                "Number of buffer rows must be greater or equal to the size of the columwise "
                "scaling block\0");
  static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y);
  static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE &&
                "Number of buffer rows must be greater or equal to the number of rowwise "
                "processing threads in Y dimension\0");

  constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X;
  constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8;  // Holds 2 elements of 4-bit size
  constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X;
  constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X;

  constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
  constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE;
  // static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X);  // there should be a sufficient number of
  //                                                   // threads to process one row in a single iteration

  constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;

  const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
  const int block_offset_X = blockIdx.x * CHUNK_DIM_X;
  const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
  const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
  const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
  const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;

  const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
  const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
  const int tid_Y_colwise = 0;
  const int tid_X_colwise = threadIdx.x;

  const int thread_offset_Y_rowwise = tid_Y_rowwise;
  const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
  const int thread_offset_Y_colwise = tid_Y_colwise;
  const int thread_offset_X_colwise = tid_X_colwise;  // Each thread processes two adjacent elements

  const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
  const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
  const int col_base_colwise = block_offset_X + thread_offset_X_colwise;

  const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);

  const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
  const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
  const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
  const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;

  const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
  const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols;

  // helps resolving bank conflicts in shmem
  const int thread_lane = threadIdx.x % THREADS_PER_WARP;
  const int bank_group = thread_lane / THREADS_PER_BANK;

  constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
  constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
  constexpr size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out_nvfp4 =
      DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out_mxfp8 =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);

  constexpr size_t buff_size_nvfp4_scales =
      CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3);
  constexpr size_t buff_size_mxfp8_scales =
      (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0);

  constexpr size_t in_mem = buff_size_aligned_in;

  constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0);
  constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0);
  constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0);
  constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0);

  extern __shared__ char dynamic_shmem[];
  uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) Does not guarantee the pointer to be aligned!
  uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
                     ~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));

  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
  IType *in_sh = reinterpret_cast<IType *>(dshmem);
  fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
  OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise_data);
  fp8e4m3 *out_rowwise_scales_sh =
      reinterpret_cast<fp8e4m3 *>(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
  e8m0_t *out_colwise_scales_sh = reinterpret_cast<e8m0_t *>(
      dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
  IType *cached_act_sh = in_sh;  // in_sh is used as a cache buffer

  constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;

  const bool is_master_thread = (threadIdx.x == 0);

  // Compute a global encoding/decoding scaling factor for all S_dec_b
  const float S_enc = (nvfp4_second_stage_scale_ptr == nullptr)
                          ? 1.0f
                          : 1.0f / (*nvfp4_second_stage_scale_ptr);

  float thread_amax = 0.0f;

// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ alignas(8) uint64_t mbar[STAGES];

  initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, is_master_thread);

  copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
                    &mbar[0], is_master_thread);

#pragma unroll
  for (int stage = 0; stage < STAGES; ++stage) {
    const int buff = stage % BUFFS_NUM;
    const int next_stage = stage + 1;
    const int stage_offset_Y = stage * BUFF_DIM_Y;

    const int buff_offset_in = buff * BUFF_IN_DIM;
    const int buff_offset_out = buff * BUFF_OUT_DIM;

    if (next_stage < STAGES) {
      // Wait for TMA transfer to have finished reading shared memory.
      // I.e. the buffer is ready to be written to
      ptx::cp_async_bulk_wait_group_read<1>();

      const int next_buff = next_stage % BUFFS_NUM;
      const int next_stage_offset_Y = next_stage * BUFF_DIM_Y;
      const int global_offset_Y = block_offset_Y + next_stage_offset_Y;
      const int global_offset_X = block_offset_X;
      const int next_buff_offset = next_buff * BUFF_IN_DIM;
      copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
                        global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
    }

    ptx::fence_proxy_async_shared_cta();

    // Wait for the data to have arrived
    ptx::mbarrier_wait_parity(&mbar[stage], 0);

    float block_amax = 0.0f;
    if constexpr (COLWISE_SCALING) {
      const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise;
      block_amax = 0.0f;
      float in_compute_colwise[SCALE_DIM_Y];
      IType in_colwise_IType[SCALE_DIM_Y];

      // 1. Read/Compute elements. Find MXFP8-block AMAX
      if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
        IType block_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
        for (int i = 0; i < SCALE_DIM_Y; ++i) {
          const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
          in_colwise_IType[i] = in_sh[shmem_offset_colwise];
          block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i]));
        }
        block_amax = static_cast<float>(block_amax_f16);
      } else {
#pragma unroll
        for (int i = 0; i < SCALE_DIM_Y; ++i) {
          const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;

          float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
          if constexpr (COMPUTE_ACTIVATIONS) {
            elt = OP(elt, {});
          }
          // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
          if constexpr (!std::is_same_v<IType, float>) {
            elt = static_cast<float>(static_cast<IType>(elt));
          }
          // Cache computed activations to avoid computing them again in the 2nd pass along another dimension
          if constexpr (IS_CACHED_ACT_OP) {
            cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
          }

          if constexpr (COMPUTE_ACTIVATIONS) {
            const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows);
            const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
            if (!out_of_bounds) {
              block_amax = fmaxf(block_amax, fabsf(elt));
            }
          } else {
            // If no activation, elt is 0 so we can safely do this
            block_amax = fmaxf(block_amax, fabsf(elt));
          }
          in_compute_colwise[i] = elt;
        }
      }

      // 2. Compute E8M0 scaling factor
      const e8m0_t biased_exponent =
          ptx::float_to_e8m0(block_amax * Quantized_Limits<OType>::max_norm_rcp);

      const int global_scales_offset_Y = scales_offset_Y_colwise + stage;
      const int global_scales_offset_X = scales_offset_X_colwise;
      const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
      if (colwise_scale_is_within_bounds) {
        scales_colwise_e8m0[scale_idx] = biased_exponent;
      }
      const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);

// 3. Scale elements
#pragma unroll
      for (int i = 0; i < SCALE_DIM_Y; ++i) {
        float in;
        if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
          in = static_cast<float>(in_colwise_IType[i]);
        } else {
          in = in_compute_colwise[i];
        }
        const float scaled_out = in * block_scale_inverse;
        const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
        out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
      }
    }

    if constexpr (ROWWISE_SCALING) {
      const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
      for (int it = 0; it < ITERATIONS_ROWWISE; ++it) {
        const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;

        const int shmem_offset_base_rowwise_in =
            buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
        const int shmem_offset_base_rowwise_out =
            buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;

        const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE;

        block_amax = 0.0f;
        float in_compute_rowwise[SCALE_DIM_X];
        Vec<IType, PACK_SIZE> in_cached[WAVES];

        // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
        Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];

        // 1. Read/Compute elements. Find NVFP4-block AMAX
        if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
          IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
            const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
            // Load elements
            in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
            for (int e = 0; e < PACK_SIZE / 2; ++e) {
              ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
            }
          }
          block_amax =
              static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
        } else if constexpr (IS_CACHED_ACT_OP) {
          // ensures that all writes to cache made in the section above are visible to all threads
          __syncthreads();
          IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
            const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;

            const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
            const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
            const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);

            // Load cached elements
            in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
            // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements)
            // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries
            if (!out_of_bounds) {
              if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
                for (int e = 0; e < PACK_SIZE; ++e) {
                  block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e]));
                }
              } else {
#pragma unroll
                for (int e = 0; e < PACK_SIZE; e += 2) {
                  const IType2 in_cached_2x = {in_cached[w].data.elt[e],
                                               in_cached[w].data.elt[e + 1]};
                  ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
                }
              }
            }
          }
          if constexpr (!std::is_same_v<IType, float>) {
            block_amax =
                static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
          }
        } else {
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
            const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;

            Vec<IType, PACK_SIZE> in;
            Vec<IType, PACK_SIZE> act_in;

            in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
            for (int e = 0; e < PACK_SIZE; ++e) {
              const int j = w * PACK_SIZE + e;
              // Compute element
              float elt = static_cast<float>(in.data.elt[e]);
              if constexpr (COMPUTE_ACTIVATIONS) {
                elt = OP(elt, {});
              }
              // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
              if constexpr (!std::is_same_v<IType, float>) {
                elt = static_cast<float>(static_cast<IType>(elt));
              }
              if constexpr (COMPUTE_ACTIVATIONS) {
                const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
                const bool swizzled_col_out_of_bounds =
                    (block_offset_X + swizzled_thread_idx >= cols);
                const bool out_of_bounds =
                    (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
                if (!out_of_bounds) {
                  block_amax = fmaxf(block_amax, fabsf(elt));
                }
              } else {
                // If no activation, elt is 0 so we can safely do this
                block_amax = fmaxf(block_amax, fabsf(elt));
              }
              in_compute_rowwise[j] = elt;
            }
          }
        }

        // 2. Compute E4M3 scaling factor
        const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc);

#if DIRECT_SCALING_FACTORS_STORE
        // Check boundaries
        if (rowwise_scale_is_within_bounds) {
          const int scales_offset_Y =
              scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE;
          const int scales_offset_X = scales_offset_X_rowwise;
          const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X;
          scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8;
        }
#else
        const int shmem_scales_offset_Y =
            stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise;
        const int shmem_scales_offset_X = tid_X_rowwise;
        const int scale_idx =
            shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X;
        out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8;
#endif
        // Compute "correct" per-block encoding scaling factor
        const float block_scale_inverse =
            __fdiv_rn(S_enc, static_cast<float>(S_dec_b_fp8));  // S_enc_b_fp8

// 3. Scale elements
#pragma unroll
        for (int w = 0; w < WAVES; ++w) {
          Vec<fp4e2m1x4, PACK_SIZE / 4> out;  // Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
          for (int e = 0; e < PACK_SIZE / 4; ++e) {
            IType2 in01;
            IType2 in23;
            if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
              in01 = in_IType[w].data.elt[2 * e];
              in23 = in_IType[w].data.elt[2 * e + 1];
            } else if constexpr (IS_CACHED_ACT_OP) {
              in01.x = in_cached[w].data.elt[4 * e];
              in01.y = in_cached[w].data.elt[4 * e + 1];
              in23.x = in_cached[w].data.elt[4 * e + 2];
              in23.y = in_cached[w].data.elt[4 * e + 3];
            } else {
              const int j = w * PACK_SIZE + 4 * e;
              in01.x = in_compute_rowwise[j];
              in01.y = in_compute_rowwise[j + 1];
              in23.x = in_compute_rowwise[j + 2];
              in23.y = in_compute_rowwise[j + 3];
            }
            fp4e2m1x4 &out_quad = reinterpret_cast<fp4e2m1x4 &>(out.data.elt[e]);
            ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse);
          }
          const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
          const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
          const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
          out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
        }
      }
    }

    __builtin_assume(thread_amax >= 0);
    __builtin_assume(block_amax >= 0);
    thread_amax = fmaxf(thread_amax, block_amax);

    // Wait for shared memory writes to be visible to TMA engine.
    ptx::fence_proxy_async_shared_cta();
    __syncthreads();
    // After syncthreads, writes by all threads are visible to TMA engine.

    // Initiate TMA transfer to copy shared memory to global memory
    if (is_master_thread) {
      const int global_offset_Y = block_offset_Y + stage_offset_Y;
      const int global_offset_X = block_offset_X;
      const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM;
      const int buff_offset_mxfp8 = buff * BUFF_IN_DIM;

      if constexpr (ROWWISE_SCALING) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
            global_offset_Y,
            reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset_nvfp4]));
      }
      if constexpr (COLWISE_SCALING) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
            global_offset_Y,
            reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset_mxfp8]));
      }
      // Create a "bulk async-group" out of the previous bulk copy operation.
      ptx::cp_async_bulk_commit_group();
    }
  }

#if !DIRECT_SCALING_FACTORS_STORE
  // Vectorized store of scaling factors.
  // Each thread stores multiple scaling factors in one store instruction.
  if constexpr (ROWWISE_SCALING) {
    // Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X
    const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x;
    const int scales_offset_X_rowwise = scales_block_offset_X_rowwise;
    const int scale_idx_global =
        scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise;
    const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW;
    if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) &&
        (scales_offset_X_rowwise < (cols / SCALE_DIM_X))) {
      using ScalesVec_t = Vec<fp8e4m3, NVFP4_SCALING_FACTORS_PER_CHUNK_ROW>;
      const ScalesVec_t &scales =
          *reinterpret_cast<ScalesVec_t *>(&out_rowwise_scales_sh[scale_idx_shmem]);
      scales.store_to(&scales_rowwise_e4m3[scale_idx_global]);
    }
  }
#endif

  float chunk_amax = 0.0f;
  if (amax_ptr != nullptr) {
    const int warp_id = threadIdx.x / THREADS_PER_WARP;
    // Reduce the amax over the block
    chunk_amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(thread_amax, warp_id);
  }

  if (is_master_thread && amax_ptr != nullptr) {
    atomicMaxFloat(amax_ptr, chunk_amax);
  }

  destroy_barriers<STAGES>(mbar, is_master_thread);
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}  // namespace nvfp4_kernel
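The two-level NVFP4 scaling in cast_nvfp4_kernel can be sanity-checked with plain floats: each 16-element block gets a decode factor S_dec_b = amax/6 scaled by the global encode factor S_enc, and elements are multiplied by S_enc / S_dec_b before the FP4 cast, so the block amax lands at the E2M1 max magnitude of 6. A minimal host-side sketch (the fp8e4m3 round-trip of S_dec_b is approximated with a plain float here):

#include <cstdio>

int main() {
  // Hypothetical inputs: per-block amax and the tensor-level second-stage scale.
  const float block_amax = 2.25f;
  const float second_stage_scale = 0.01f;  // what nvfp4_second_stage_scale_ptr would point at
  const float S_enc = 1.0f / second_stage_scale;

  // Per-block decode factor, as in compute_decoding_scaling_factor():
  const float S_dec_b = block_amax * (1.0f / 6.0f) * S_enc;  // stored as fp8e4m3 in the kernel

  // Per-block encode factor applied to the data before the FP4 cast:
  const float block_scale_inverse = S_enc / S_dec_b;

  // An element at the block amax should land near the FP4 E2M1 max magnitude (6).
  std::printf("scaled amax = %f\n", block_amax * block_scale_inverse);
  return 0;
}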

constexpr size_t FP8_CHUNK_DIM_Y = 128;
constexpr size_t FP8_CHUNK_DIM_X = 128;
constexpr size_t FP8_THREADS_PER_CHUNK = 128;
...
@@ -903,7 +1431,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows,
}

template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) {
  const size_t N = product(input.data.shape);
  const bool isFullTile = (N % ELEMS_PER_BLOCK == 0);
...
@@ -1192,6 +1720,141 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
#endif
}

// This kernel supports only two scaling cases:
//   1. r16c0  - Rowwise NVFP4
//   2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &)>
void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) {
  using namespace nvfp4_kernel;
  using namespace ptx;
  checkCuDriverContext(stream);

  NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated.");
  NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
  NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
  NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");

  bool use_colwise_scaling = output->has_columnwise_data();
  if (use_colwise_scaling) {
    NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
               "Columnwise scaling tensor must be allocated");
  }
  CheckNoopTensor(*noop, "cast_noop");

  const size_t rows = input.flat_first_dim();
  const size_t cols = input.flat_last_dim();

  constexpr size_t CHUNK_DIM_Y = 128;
  constexpr size_t CHUNK_DIM_X = 128;
  constexpr size_t THREADS_PER_CHUNK = 128;
  constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;

  const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
  const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
  const dim3 grid(blocks_X, blocks_Y);
  const size_t block_size = THREADS_PER_CHUNK;

  const size_t scale_stride_rowwise = output->scale_inv.shape[1];
  const size_t scale_stride_colwise =
      use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1;

  fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast<fp8e4m3 *>(output->scale_inv.dptr);
  e8m0_t *const scales_colwise_e8m0_ptr =
      use_colwise_scaling ? reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr) : nullptr;
  const ScalingType scaling_type =
      use_colwise_scaling ? ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE;

  float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
  const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
  const float *const nvfp4_second_stage_scale_ptr =
      reinterpret_cast<const float *>(output->scale.dptr);

  // Output data type is only required for the column-wise MXFP8 scaling.
  // It has no effect for the row-wise NVFP4 scaling, but is set to the default E4M3 for the macros to work
  const DType output_data_type =
      use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3;

  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      input.dtype(), IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
          output_data_type, OType,

          alignas(64) CUtensorMap tensor_map_input{};
          alignas(64) CUtensorMap tensor_map_output_rowwise{};
          alignas(64) CUtensorMap tensor_map_output_colwise{};

          create_2D_tensor_map(tensor_map_input, input.data, rows, cols, nvfp4_kernel::BUFF_DIM_Y,
                               BUFF_DIM_X, cols, 0, sizeof(IType) * 8);
          create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols,
                               nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, 4);

          if (use_colwise_scaling) {
            create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols,
                                 nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8);
          }

          constexpr size_t buff_elems = nvfp4_kernel::BUFF_DIM_Y * BUFF_DIM_X;
          constexpr size_t buff_elems_total = nvfp4_kernel::BUFFS_NUM * buff_elems;
          constexpr size_t buff_size_aligned_in =
              DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
          constexpr size_t buff_size_aligned_out_nvfp4 =
              DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
          constexpr size_t buff_size_aligned_out_mxfp8 =
              DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
          constexpr size_t buff_size_nvfp4_scales =
              (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3);
          constexpr size_t buff_size_mxfp8_scales =
              (CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t);

          constexpr size_t in_mem = buff_size_aligned_in;

          const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4;
          const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0;
          const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales;
          const size_t out_colwise_scales_mem = use_colwise_scaling ? buff_size_mxfp8_scales : 0;

          const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem +
                                 out_rowwise_scales_mem + out_colwise_scales_mem +
                                 TMA_SHMEM_ALIGNMENT;

          const size_t dshmem_size = in_mem + out_mem;

          switch (scaling_type) {
            case ScalingType::ROWWISE:
              cudaFuncSetAttribute(
                  cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, false,
                                    CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
                  cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
              cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, false, CHUNK_DIM_Y,
                                CHUNK_DIM_X, THREADS_PER_CHUNK>
                  <<<grid, block_size, dshmem_size, stream>>>(
                      tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
                      scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
                      nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
                      scale_stride_colwise);
              break;
            case ScalingType::BIDIMENSIONAL:
              cudaFuncSetAttribute(
                  cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, true,
                                    CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
                  cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
              cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, true, CHUNK_DIM_Y,
                                CHUNK_DIM_X, THREADS_PER_CHUNK>
                  <<<grid, block_size, dshmem_size, stream>>>(
                      tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
                      scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
                      nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
                      scale_stride_colwise);
              break;
          });  // NOLINT(*)
  );           // NOLINT(*)
}
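The dynamic shared-memory request assembled above is a plain sum of the aligned staging buffers. A quick host-side check of the numbers for the default 128x128 chunk with BF16 input and an FP8 column-wise output, assuming TMA_SHMEM_ALIGNMENT is 128 bytes (an assumption of this sketch):

#include <cstddef>
#include <cstdio>

size_t divup_to_multiple(size_t x, size_t m) { return ((x + m - 1) / m) * m; }

int main() {
  // Assumed constants (mirroring nvfp4_quantize for a 128x128 chunk, BF16 in, FP8 colwise out).
  const size_t TMA_SHMEM_ALIGNMENT = 128;  // assumption for this sketch
  const size_t CHUNK_Y = 128, CHUNK_X = 128, BUFF_Y = 32, BUFFS_NUM = 2;
  const size_t buff_elems_total = BUFFS_NUM * BUFF_Y * CHUNK_X;

  const size_t in_mem    = divup_to_multiple(buff_elems_total * 2, TMA_SHMEM_ALIGNMENT);      // BF16
  const size_t out_nvfp4 = divup_to_multiple(buff_elems_total * 4 / 8, TMA_SHMEM_ALIGNMENT);  // 4-bit
  const size_t out_mxfp8 = divup_to_multiple(buff_elems_total * 1, TMA_SHMEM_ALIGNMENT);      // FP8
  const size_t nvfp4_scales = (CHUNK_Y * CHUNK_X) / 16 * 1;
  const size_t mxfp8_scales = (CHUNK_Y * CHUNK_X) / 32 * 1;

  const size_t out_mem =
      out_nvfp4 + out_mxfp8 + nvfp4_scales + mxfp8_scales + TMA_SHMEM_ALIGNMENT;
  std::printf("dshmem_size = %zu bytes\n", in_mem + out_mem);  // roughly 30 KB for this config
  return 0;
}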

namespace detail {

using Empty = transformer_engine::Empty;
...
@@ -1417,13 +2080,26 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
  auto dbias_tensor = convertNVTETensor(dbias);
  auto workspace_tensor = convertNVTETensor(workspace);

  // Quantization config
  QuantizationConfig quant_config_cpp;
  if (quant_config != nullptr) {
    quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
  }

  // Noop flag
  Tensor dummy_tensor;
  Tensor *noop_tensor = &dummy_tensor;
  if (quant_config_cpp.noop_tensor != nullptr) {
    noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
  }

  // Check for unsupported options
  if (quant_config_cpp.stochastic_rounding) {
    NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
               "Stochastic rounding is only supported for NVFP4 quantization.");
  }

  // Dispatch to quantization kernel depending on data format
  switch (output_tensor->scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
      if (output_tensor->has_columnwise_data()) {
...
@@ -1435,7 +2111,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
        NVTE_CHECK(output_tensor->has_data(),
                   "Quantizing in only the columnwise direction not supported yet!");
        if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) {
          cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream);
        } else {
          cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, float, ParamOP, OP>(
              *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor,
...
@@ -1443,51 +2119,90 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
        }
      } else if (output_tensor->has_data()) {
        fp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(
            *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor,
            workspace_tensor, stream);
      }
      break;
    }
    case NVTE_MXFP8_1D_SCALING: {
      mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(
          *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor,
          workspace_tensor, stream);
      break;
    }
    case NVTE_NVFP4_1D_SCALING: {
      // Check tensors
      CheckNoopTensor(*noop_tensor, "cast_noop");
      CheckInputTensor(*input_tensor, "input");
      CheckOutputTensor(*output_tensor, "output", false);

      // Choose kernel
      int32_t rows = input_tensor->flat_first_dim();
      int32_t cols = input_tensor->flat_last_dim();
      auto dtype = input_tensor->dtype();
      bool use_optimized_kernel = dtype == DType::kBFloat16 && rows % 32 == 0 && cols % 32 == 0 &&
                                  output_tensor->has_data();

      // Launch NVFP4 quantize kernel
      if (use_optimized_kernel) {
        if (quant_config_cpp.nvfp4_2d_quantization) {
          nvfp4_quantize_transpose<IS_ACT, ParamOP, OP, true>(*input_tensor, noop_tensor,
                                                              output_tensor, &quant_config_cpp,
                                                              stream);
        } else {
          nvfp4_quantize_transpose<IS_ACT, ParamOP, OP, false>(*input_tensor, noop_tensor,
                                                               output_tensor, &quant_config_cpp,
                                                               stream);
        }
      } else {
        auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
                                                                  : output_tensor->columnwise_amax;
        NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
                   "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_NVFP4_1D_SCALING for "
                   "2D quantization");
        quantize_transpose_vector_blockwise_fp4(
            /*input=*/input_tensor->data, /*global_amax=*/global_amax,
            /*scale_inv=*/output_tensor->scale_inv,
            /*scale_inv_t=*/output_tensor->columnwise_scale_inv,
            /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
            /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
            /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
            /*swizzled_scale=*/false,
            /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
            /*rng_state=*/quant_config_cpp.rng_state,
            /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
            /*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
      }
      break;
    }
    case NVTE_BLOCK_SCALING_2D: {
      // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
      NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
                 "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
      quantize_transpose_square_blockwise(
          input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
          output_tensor->data, output_tensor->columnwise_data, epsilon,
          /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
output_tensor
->
has_columnwise_data
(),
force_pow_2_scales
,
/*noop_tensor=*/
noop_tensor
.
data
,
stream
);
/*noop_tensor=*/
noop_tensor
->
data
,
stream
);
break
;
break
;
}
}
case
NVTE_BLOCK_SCALING_1D
:
{
case
NVTE_BLOCK_SCALING_1D
:
{
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK
((
!
IS_DBIAS
&&
!
IS_DACT
&&
!
IS_ACT
),
NVTE_CHECK
((
!
IS_DBIAS
&&
!
IS_DACT
&&
!
IS_ACT
),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"
);
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"
);
bool
force_pow_2_scales
=
quant_config_cpp
?
quant_config_cpp
->
force_pow_2_scales
:
false
;
bool
force_pow_2_scales
=
quant_config_cpp
.
force_pow_2_scales
;
float
epsilon
=
quant_config_cpp
?
quant_config_cpp
->
amax_epsilon
:
0.0
f
;
float
epsilon
=
quant_config_cpp
.
amax_epsilon
;
FP8BlockwiseRowwiseOption
rowwise_option
=
FP8BlockwiseRowwiseOption
::
NONE
;
FP8BlockwiseRowwiseOption
rowwise_option
=
FP8BlockwiseRowwiseOption
::
NONE
;
FP8BlockwiseColumnwiseOption
columnwise_option
=
FP8BlockwiseColumnwiseOption
::
NONE
;
FP8BlockwiseColumnwiseOption
columnwise_option
=
FP8BlockwiseColumnwiseOption
::
NONE
;
if
(
output_tensor
->
has_data
())
{
if
(
output_tensor
->
has_data
())
{
bool
rowwise_compact
=
quant_config_cpp
bool
rowwise_compact
=
(
quant_config_cpp
.
float8_block_scale_tensor_format
==
?
quant_config_cpp
->
float8_block_scale_tensor_format
==
Float8BlockScaleTensorFormat
::
COMPACT
);
Float8BlockScaleTensorFormat
::
COMPACT
:
false
;
rowwise_option
=
rowwise_compact
?
FP8BlockwiseRowwiseOption
::
ROWWISE_COMPACT
rowwise_option
=
rowwise_compact
?
FP8BlockwiseRowwiseOption
::
ROWWISE_COMPACT
:
FP8BlockwiseRowwiseOption
::
ROWWISE_GEMM_READY
;
:
FP8BlockwiseRowwiseOption
::
ROWWISE_GEMM_READY
;
}
}
if
(
output_tensor
->
has_columnwise_data
())
{
if
(
output_tensor
->
has_columnwise_data
())
{
bool
columnwise_compact
=
quant_config_cpp
bool
columnwise_compact
=
(
quant_config_cpp
.
float8_block_scale_tensor_format
==
?
quant_config_cpp
->
float8_block_scale_tensor_format
==
Float8BlockScaleTensorFormat
::
COMPACT
);
Float8BlockScaleTensorFormat
::
COMPACT
:
false
;
columnwise_option
=
columnwise_compact
columnwise_option
=
columnwise_compact
?
FP8BlockwiseColumnwiseOption
::
COLUMNWISE_COMPACT
?
FP8BlockwiseColumnwiseOption
::
COLUMNWISE_COMPACT
:
FP8BlockwiseColumnwiseOption
::
COLUMNWISE_GEMM_READY
;
:
FP8BlockwiseColumnwiseOption
::
COLUMNWISE_GEMM_READY
;
...
@@ -1495,7 +2210,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
...
@@ -1495,7 +2210,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
quantize_transpose_vector_blockwise
(
quantize_transpose_vector_blockwise
(
input_tensor
->
data
,
output_tensor
->
scale_inv
,
output_tensor
->
columnwise_scale_inv
,
input_tensor
->
data
,
output_tensor
->
scale_inv
,
output_tensor
->
columnwise_scale_inv
,
output_tensor
->
data
,
output_tensor
->
columnwise_data
,
epsilon
,
rowwise_option
,
output_tensor
->
data
,
output_tensor
->
columnwise_data
,
epsilon
,
rowwise_option
,
columnwise_option
,
force_pow_2_scales
,
noop_tensor
.
data
,
stream
);
columnwise_option
,
force_pow_2_scales
,
noop_tensor
->
data
,
stream
);
break
;
break
;
}
}
default:
default:
...
...
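A minimal sketch (not part of this commit) of the fused-kernel gate used in the NVFP4 case above, pulled out as a standalone predicate. DType::kBFloat16 and the 32-element alignment requirement come directly from the diff; the helper name can_use_fused_nvfp4_kernel is hypothetical.

static bool can_use_fused_nvfp4_kernel(DType dtype, int64_t rows, int64_t cols,
                                       bool has_rowwise_output) {
  // Mirrors use_optimized_kernel in quantize_helper(): the optimized NVFP4 cast+transpose
  // path requires BF16 input, 32-aligned shapes, and a rowwise output buffer.
  return dtype == DType::kBFloat16 && rows % 32 == 0 && cols % 32 == 0 && has_rowwise_output;
}

Shapes that fail this predicate fall back to the generic quantize_transpose_vector_blockwise_fp4 path shown above.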
transformer_engine/common/util/dequantize_kernels.cuh
View file @ 063ef88d
...
@@ -19,6 +19,8 @@
 #include <transformer_engine/cast.h>
 
 #include <cfloat>
+#include <cstddef>
+#include <cstdint>
 #include <limits>
 
 #include "../common.h"
...
@@ -28,6 +30,7 @@
 #include "math.h"
 #include "ptx.cuh"
 #include "transformer_engine/activation.h"
+#include "transformer_engine/transformer_engine.h"
 #include "transformer_engine/transpose.h"
 
 namespace transformer_engine {
...
@@ -339,6 +342,81 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
   NVTE_CHECK_CUDA(cudaGetLastError());
 #endif
 }
 
+#if CUDA_VERSION >= 12080
+template <typename OType>
+__global__ void __launch_bounds__(512)
+    dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales,
+                          const float *const tensor_amax, const size_t N, const size_t M,
+                          const size_t scale_stride) {
+  const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t x = thread_idx % M;
+  const size_t y = thread_idx / M;
+  union fp4vec {
+    uint64_t vec;
+    fp4e2m1x4 small_vec[4];
+  };
+  using OVec = Vec<OType, 4>;
+  const uint64_t *const input_vectorized = reinterpret_cast<const uint64_t *>(input);
+  OVec *output_vec = reinterpret_cast<OVec *>(output);
+  const size_t my_index = x + y * M;
+  const size_t my_scale_index = x + y * scale_stride;
+  const size_t my_output_index = (x + y * M) * 4;
+  fp4vec value;
+  value.vec = input_vectorized[my_index];
+  fp8e4m3 scale = scales[my_scale_index];
+  float amax = *tensor_amax;
+  constexpr float factor_inv = 1.0 / (6.0 * 448.0);
+  float final_scale = static_cast<float>(scale) * amax * factor_inv;
+#pragma unroll
+  for (int i = 0; i < 4; i++) {
+    float4 current = static_cast<float4>(value.small_vec[i]);
+    OVec out;
+    out.data.elt[0] = static_cast<OType>(current.x * final_scale);
+    out.data.elt[1] = static_cast<OType>(current.y * final_scale);
+    out.data.elt[2] = static_cast<OType>(current.z * final_scale);
+    out.data.elt[3] = static_cast<OType>(current.w * final_scale);
+    output_vec[my_output_index + i] = out;
+  }
+}
+#endif  // CUDA_VERSION
+
+void fp4_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
+#if CUDA_VERSION >= 12080
+  CheckInputTensor(input, "input");
+  CheckOutputTensor(*output, "output");
+  NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type.");
+  NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision.");
+  NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
+
+  constexpr int FP4_BLOCK_SIZE = 16;
+  const size_t N = input.flat_first_dim();
+  const size_t M = input.flat_last_dim();
+  NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ",
+             FP4_BLOCK_SIZE, ", but got ", input.data.shape, ".");
+
+  const size_t Mread = M / FP4_BLOCK_SIZE;
+  const size_t total = N * Mread;
+  const size_t threads = 512;
+  const size_t blocks = DIVUP(total, threads);
+
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+      output->data.dtype, OType,
+      dequantize_fp4_kernel<<<blocks, threads, 0, stream>>>(
+          input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
+          reinterpret_cast<fp8e4m3 *>(input.scale_inv.dptr),
+          reinterpret_cast<float *>(input.amax.dptr), N, Mread,
+          input.scale_inv.shape.back()););  // NOLINT(*)
+  NVTE_CHECK_CUDA(cudaGetLastError());
+#else
+  NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!");
+#endif  // CUDA_VERSION >= 12080
+}
+
 }  // namespace dequantization
 
 namespace detail {
...
@@ -347,17 +425,25 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream)
   CheckInputTensor(input, "cast_input");
   CheckOutputTensor(*output, "cast_output");
 
-  if (is_tensor_scaling(input.scaling_mode)) {
-    dequantization::fp8_dequantize(input, output, stream);
-  } else if (is_mxfp_scaling(input.scaling_mode)) {
-    if (is_supported_by_CC_100()) {
-      dequantization::mxfp8_dequantize(input, output, stream);
-    } else {
-      NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
-    }
-  } else {
-    // TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING
-    NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
-  }
+  switch (input.scaling_mode) {
+    case NVTE_DELAYED_TENSOR_SCALING: {
+      dequantization::fp8_dequantize(input, output, stream);
+      break;
+    }
+    case NVTE_MXFP8_1D_SCALING: {
+      if (is_supported_by_CC_100()) {
+        dequantization::mxfp8_dequantize(input, output, stream);
+      } else {
+        NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
+      }
+      break;
+    }
+    case NVTE_NVFP4_1D_SCALING: {
+      dequantization::fp4_dequantize(input, output, stream);
+      break;
+    }
+    default:
+      NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
+  }
 }
...
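A minimal host-side sketch (not part of this commit) of the per-block scale computed in dequantize_fp4_kernel above; the helper name fp4_block_dequant_scale is hypothetical. Each 16-element NVFP4 block stores an E4M3 scale, and the tensor amax folds back the global encode factor 448 * 6.

static float fp4_block_dequant_scale(float e4m3_scale, float tensor_amax) {
  // final_scale in the kernel: scale * amax / (6.0 * 448.0)
  constexpr float factor_inv = 1.0f / (6.0f * 448.0f);
  return e4m3_scale * tensor_amax * factor_inv;
}

Each decoded FP4 value is multiplied by this factor before being cast to the output type.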
transformer_engine/common/util/logging.h
View file @ 063ef88d
...
@@ -23,6 +23,8 @@
 #endif  // __HIP_PLATFORM_AMD__
 
 #include <nvrtc.h>
+
+#include "nccl.h"
 
 #ifdef NVTE_WITH_CUBLASMP
 #include <cublasmp.h>
 #endif  // NVTE_WITH_CUBLASMP
...
@@ -147,4 +149,12 @@
 #endif  // NVTE_WITH_CUBLASMP
 
+#define NVTE_CHECK_NCCL(expr)                                                    \
+  do {                                                                           \
+    const ncclResult_t status_NVTE_CHECK_NCCL = (expr);                          \
+    if (status_NVTE_CHECK_NCCL != ncclSuccess) {                                 \
+      NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL));    \
+    }                                                                            \
+  } while (false)
+
 #endif  // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
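A minimal usage sketch (not part of this commit) for the NVTE_CHECK_NCCL macro added above, assuming an already-initialized NCCL communicator; ncclAllReduce is the standard NCCL API.

void all_reduce_inplace(float *buf, size_t count, ncclComm_t comm, cudaStream_t stream) {
  // Any status other than ncclSuccess is turned into an NVTE_ERROR carrying the NCCL error string.
  NVTE_CHECK_NCCL(ncclAllReduce(buf, buf, count, ncclFloat, ncclSum, comm, stream));
}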
transformer_engine/common/util/math.h
View file @ 063ef88d
...
@@ -11,6 +11,11 @@ namespace transformer_engine {
 struct Empty {};
 
+struct ClampedSwiGLUParam {
+  float limit;
+  float alpha = 1.702f;  // Default value for QuickGELU
+};
+
 template <typename OType, typename IType>
 __device__ inline OType gelu(const IType val, const Empty&) {
   const float cval = val;
...
@@ -38,17 +43,29 @@ __device__ inline OType dsigmoid(const IType val, const Empty& e) {
   return s * (1.f - s);
 }
 
+template <typename OType, typename IType>
+__device__ inline OType qgelu_with_alpha(const IType val, const float alpha) {
+  const float cval = val;
+  Empty e = {};
+  return cval * sigmoid<float, float>(alpha * cval, e);
+}
+
 template <typename OType, typename IType>
 __device__ inline OType qgelu(const IType val, const Empty& e) {
-  const float cval = val;
-  return cval * sigmoid<float, float>(1.702f * cval, e);
+  return qgelu_with_alpha<OType, IType>(val, 1.702f);
 }
 
+template <typename OType, typename IType>
+__device__ inline OType dqgelu_with_alpha(const IType val, const float alpha) {
+  const float cval = val;
+  Empty e = {};
+  return alpha * cval * dsigmoid<float, float>(alpha * cval, e) +
+         sigmoid<float, float>(alpha * cval, e);
+}
+
 template <typename OType, typename IType>
 __device__ inline OType dqgelu(const IType val, const Empty& e) {
-  const float cval = val;
-  return 1.702f * cval * dsigmoid<float, float>(1.702f * cval, e) +
-         sigmoid<float, float>(1.702f * cval, e);
+  return dqgelu_with_alpha<OType, IType>(val, 1.702f);
 }
 
 template <typename OType, typename IType>
...
@@ -57,12 +74,26 @@ __device__ inline OType silu(const IType val, const Empty& e) {
   return cval * sigmoid<float, float>(cval, e);
 }
 
+template <typename OType, typename IType>
+__device__ inline OType clamped_silu(const IType val, const ClampedSwiGLUParam& p) {
+  const float cval = min(p.limit, static_cast<float>(val));  // Clamping
+  return qgelu_with_alpha<OType, float>(cval, p.alpha);
+}
+
 template <typename OType, typename IType>
 __device__ inline OType dsilu(const IType val, const Empty& e) {
   const float cval = val;
   return cval * dsigmoid<float, float>(cval, e) + sigmoid<float, float>(cval, e);
 }
 
+template <typename OType, typename IType>
+__device__ inline OType clamped_dsilu(const IType val, const ClampedSwiGLUParam& p) {
+  const bool dclamp_val = static_cast<float>(val) <= p.limit;
+  const float clamp_val = min(static_cast<float>(val), p.limit);
+  const float dsilu_val = dqgelu_with_alpha<OType, float>(clamp_val, p.alpha);
+  return dclamp_val ? dsilu_val : 0.0f;
+}
+
 template <typename OType, typename IType>
 __device__ inline OType relu(IType value, const Empty&) {
   return fmaxf(value, 0.f);
...
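A minimal host-side reference (not part of this commit) for the clamped SiLU pair added above; the *_ref names are hypothetical helpers that restate the same formulas in plain C++ for quick numerical checks against the device code.

#include <algorithm>
#include <cmath>

static float sigmoid_ref(float x) { return 1.0f / (1.0f + std::exp(-x)); }

static float clamped_silu_ref(float x, float limit, float alpha = 1.702f) {
  const float c = std::min(x, limit);   // clamp the input first
  return c * sigmoid_ref(alpha * c);    // same form as qgelu_with_alpha
}

static float clamped_dsilu_ref(float x, float limit, float alpha = 1.702f) {
  if (x > limit) return 0.0f;           // derivative is zero in the clamped region
  const float s = sigmoid_ref(alpha * x);
  return alpha * x * s * (1.0f - s) + s;  // d/dx [x * sigmoid(alpha * x)]
}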
transformer_engine/common/util/nvfp4_transpose.cuh
0 → 100644
View file @
063ef88d
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file nvfp4_transpose.cuh
* \brief CUDA kernels to cast to NVFP4 and transpose.
*/
#ifndef TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
#define TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#if CUDA_VERSION > 12080
#include <cuda_fp4.h>
#endif // CUDA_VERSION > 12080
#include <cfloat>
#include "../common.h"
#include "../utils.cuh"
#include "curanddx.hpp"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {

#if CUDA_VERSION > 12080

namespace nvfp4_transpose {

using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
                     curanddx::SM<800>() + curanddx::Thread());

using namespace ptx;

using nvfp4_scale_t = fp8e4m3;

constexpr size_t SCALE_DIM = 16;  // NVFP4 block (x16 elts)

constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_NUM = 128;

constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM;
constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM;
constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM;
constexpr size_t RNG_GENS_PER_THREAD =
    SCALES_PER_THREAD / 4;  // Each call generates 4x uint32_t random numbers

constexpr size_t TILE_DIM_Y = 32;
constexpr size_t TILE_DIM_X = 128;

// Should this be SCALE_DIM or BLOCK_DIM? Both are 16, so it works for both 1D and 2D
constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM;
constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM;  // 128 / 16 = 8

constexpr size_t TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y;
constexpr size_t TILES_X = CHUNK_DIM_X / TILE_DIM_X;
constexpr size_t STAGES = TILES_Y * TILES_X;

constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = TILE_DIM_Y;
constexpr size_t BUFF_DIM_X = TILE_DIM_X;
constexpr size_t BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM;

// Input buffer (BF16)
constexpr size_t BUFF_IN_DIM_Y = BUFF_DIM_Y;
constexpr size_t BUFF_IN_DIM_X = BUFF_DIM_X;
constexpr size_t BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X;

// Output buffer (NVFP4)
constexpr size_t BUFF_OUT_DIM_Y = BUFF_DIM_Y;
constexpr size_t BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8;
constexpr size_t BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X;

// Output transpose buffer (NVFP4)
constexpr size_t BUFF_OUT_T_DIM_Y = BUFF_DIM_X;
constexpr size_t BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8;
constexpr size_t BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X;

// Manual swizzling parameters to reduce SHMEM bank conflicts
constexpr size_t PACK_SIZE = 8;
constexpr size_t WAVES = SCALE_DIM / PACK_SIZE;

constexpr size_t SCALING_FACTORS_PER_TILE_X = TILE_DIM_X / SCALE_DIM;
constexpr size_t THREADS_X_ROWWISE = SCALING_FACTORS_PER_TILE_X;       // 128 / 16 = 8
constexpr size_t THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE;  // 128 / 8 = 16

constexpr size_t ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE;  // 32 / 16 = 2
constexpr size_t ITERATIONS_TRANSPOSE = BUFF_IN_DIM_Y / SCALE_DIM;
constexpr size_t BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE;

static_assert(BUFF_DIM_Y >= SCALE_DIM &&
              "Number of buffer rows must be greater or equal to the size of the columnwise "
              "scaling block\0");
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y);
static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE &&
              "Number of buffer rows must be greater or equal to the number of rowwise "
              "processing threads in Y dimension\0");

// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4;  // 256

// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM;  // 8 = 128 / 16

// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ nvfp4_scale_t
compute_decoding_scaling_factor(const float block_amax, const float S_enc) {
  // constexpr float rcp_6f = 1.0f / 6.0f;
  // const float S_dec_b = block_amax * rcp_6f;
  // const nvfp4_scale_t S_dec_b_fp8 = static_cast<nvfp4_scale_t>(S_dec_b * S_enc);
  // return S_dec_b_fp8;
  // NOTE: Divide by 6.0f is not elegant and not efficient.
  // However, this is part of the emulation code to ensure exact match.
  using namespace detail;
  constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;  // 6.0f;
  const float S_dec_b = block_amax / fp4_max * S_enc;
  return static_cast<nvfp4_scale_t>(fminf(S_dec_b, TypeExtrema<float>::max));
}

// Compute the global encode scale factor for a given global amax
__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) {
  using namespace detail;
  constexpr float fp8_max = TypeExtrema<fp8e4m3>::max;  // 448.0f;
  constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;  // 6.0f;
  float global_encode_scale = fp8_max * fp4_max / global_amax;
  // If scale is infinity, return max value of float32
  global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
  // If global amax is 0 or infinity, return 1
  if (global_amax == 0.0f || global_encode_scale == 0.0f) {
    return 1.0f;
  }
  return global_encode_scale;
}

__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) {
  if (rnd_idx == 4) {
    rnd_idx = 0;
    curanddx::uniform_bits dist;
    random_uint4 = dist.generate4(rng);
  }
  // Treat uint4 as an array of 4x uint32_t elements for indexing
  const uint32_t *const rbits_arr = reinterpret_cast<uint32_t *>(&random_uint4);
  const uint32_t rbits = rbits_arr[rnd_idx++];
  return rbits;
}
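// Hedged reference sketch (not part of the original file): host-side restatement of the
// two-level NVFP4 scaling computed by the helpers above. The *_ref names are hypothetical.
// S_enc maps the global amax onto the E4M3 x E2M1 range (448 * 6); each 16-element block
// then stores its own E4M3 decode scale derived from the block amax.
static inline float global_encode_scale_ref(float global_amax) {
  float s = 448.0f * 6.0f / global_amax;       // fp8_max * fp4_max / amax
  s = fminf(s, detail::TypeExtrema<float>::max);  // clamp +inf when amax == 0
  if (global_amax == 0.0f || s == 0.0f) {
    return 1.0f;  // degenerate amax -> identity scale
  }
  return s;
}

static inline float block_decode_scale_ref(float block_amax, float s_enc) {
  // Stored per 16-element block as an E4M3 value; see compute_decoding_scaling_factor().
  return block_amax / 6.0f * s_enc;
}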
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__device__
__forceinline__
fp4e2m1x4
mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding
(
const
uint64_t
in_4x
,
const
float2
scale
,
const
uint32_t
rbits
)
{
uint16_t
out_4x
=
0
;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm
volatile
(
"{
\n
"
".reg.b64 v01;
\n\t
"
".reg.b64 v23;
\n\t
"
".reg.b16 v0_bf16;
\n\t
"
".reg.b16 v1_bf16;
\n\t
"
".reg.b16 v2_bf16;
\n\t
"
".reg.b16 v3_bf16;
\n\t
"
".reg.b32 v0;
\n\t
"
".reg.b32 v1;
\n\t
"
".reg.b32 v2;
\n\t
"
".reg.b32 v3;
\n\t
"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1;
\n\t
"
"cvt.f32.bf16 v0, v0_bf16;
\n\t
"
"cvt.f32.bf16 v1, v1_bf16;
\n\t
"
"cvt.f32.bf16 v2, v2_bf16;
\n\t
"
"cvt.f32.bf16 v3, v3_bf16;
\n\t
"
"mov.b64 v01, {v0, v1};
\n\t
"
"mov.b64 v23, {v2, v3};
\n\t
"
"mul.f32x2 v01, v01, %2;
\n\t
"
// mind the shuffled elements order
"mul.f32x2 v23, v23, %2;
\n\t
"
// mind the shuffled elements order
"mov.b64 {v1, v0}, v01;
\n\t
"
"mov.b64 {v3, v2}, v23;
\n\t
"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3;
\n\t
"
// mind the shuffled elements order
"}"
:
"=h"
(
out_4x
)
:
"l"
(
in_4x
),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)),
"r"
(
rbits
));
#else
NVTE_DEVICE_ERROR
(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX."
);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return
*
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
);
}
__device__
__forceinline__
fp4e2m1x4
mul_cvt_bf16_to_fp4_4x_with_rn
(
const
uint64_t
in_4x
,
const
float2
scale
,
const
uint32_t
rbits
)
{
// NOTE: rbits unused for rn.
uint32_t
out_4x
=
0
;
// Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm
volatile
(
"{
\n
"
".reg.b64 v01;
\n\t
"
".reg.b64 v23;
\n\t
"
".reg.b16 v0_bf16;
\n\t
"
".reg.b16 v1_bf16;
\n\t
"
".reg.b16 v2_bf16;
\n\t
"
".reg.b16 v3_bf16;
\n\t
"
".reg.b32 v0;
\n\t
"
".reg.b32 v1;
\n\t
"
".reg.b32 v2;
\n\t
"
".reg.b32 v3;
\n\t
"
".reg.b8 f0;
\n\t
"
".reg.b8 f1;
\n\t
"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1;
\n\t
"
"cvt.f32.bf16 v0, v0_bf16;
\n\t
"
"cvt.f32.bf16 v1, v1_bf16;
\n\t
"
"cvt.f32.bf16 v2, v2_bf16;
\n\t
"
"cvt.f32.bf16 v3, v3_bf16;
\n\t
"
"mov.b64 v01, {v0, v1};
\n\t
"
"mov.b64 v23, {v2, v3};
\n\t
"
"mul.f32x2 v01, v01, %2;
\n\t
"
// mind the shuffled elements order
"mul.f32x2 v23, v23, %2;
\n\t
"
// mind the shuffled elements order
"mov.b64 {v1, v0}, v01;
\n\t
"
"mov.b64 {v3, v2}, v23;
\n\t
"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;
\n\t
"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;
\n\t
"
"mov.b32 %0, {f0, f1, f0, f1};
\n\t
"
"}"
:
"=r"
(
out_4x
)
:
"l"
(
in_4x
),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
#else
NVTE_DEVICE_ERROR
(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX."
);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
)[
0
];
}
template
<
bool
USE_STOCHASTIC_ROUNDING
>
__device__
__forceinline__
fp4e2m1x4
mul_cvt_bf16_to_fp4_4x
(
const
uint64_t
in_4x
,
const
float2
scale
,
const
uint32_t
rbits
)
{
if
constexpr
(
USE_STOCHASTIC_ROUNDING
)
{
return
mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding
(
in_4x
,
scale
,
rbits
);
}
else
{
return
mul_cvt_bf16_to_fp4_4x_with_rn
(
in_4x
,
scale
,
rbits
);
}
}
__device__
__forceinline__
fp4e2m1x4
mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding
(
const
float2
in01
,
const
float2
in23
,
const
float2
scale
,
const
uint32_t
rbits
)
{
uint16_t
out_4x
=
0
;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm
volatile
(
"{
\n
"
".reg.b64 v01;
\n\t
"
".reg.b64 v23;
\n\t
"
".reg.b32 v0;
\n\t
"
".reg.b32 v1;
\n\t
"
".reg.b32 v2;
\n\t
"
".reg.b32 v3;
\n\t
"
"mov.b64 {v0, v1} , %1;
\n\t
"
"mov.b64 {v2, v3} , %2;
\n\t
"
"mov.b64 v01, {v0, v1};
\n\t
"
"mov.b64 v23, {v2, v3};
\n\t
"
"mul.f32x2 v01, v01, %3;
\n\t
"
// mind the shuffled elements order
"mul.f32x2 v23, v23, %3;
\n\t
"
// mind the shuffled elements order
"mov.b64 {v1, v0}, v01;
\n\t
"
"mov.b64 {v3, v2}, v23;
\n\t
"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4;
\n\t
"
// mind the shuffled elements order
"}"
:
"=h"
(
out_4x
)
:
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in01
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in23
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)),
"r"
(
rbits
));
#else
NVTE_DEVICE_ERROR
(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX."
);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return
*
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
);
}
__device__
__forceinline__
fp4e2m1x4
mul_cvt_fp32_to_fp4_4x_with_rn
(
const
float2
in01
,
const
float2
in23
,
const
float2
scale
,
const
uint32_t
rbits
)
{
// NOTE: rbits unused for rn.
uint32_t
out_4x
=
0
;
// Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm
volatile
(
"{
\n
"
".reg.b64 v01;
\n\t
"
".reg.b64 v23;
\n\t
"
".reg.b32 v0;
\n\t
"
".reg.b32 v1;
\n\t
"
".reg.b32 v2;
\n\t
"
".reg.b32 v3;
\n\t
"
".reg.b8 f0;
\n\t
"
".reg.b8 f1;
\n\t
"
"mov.b64 {v0, v1} , %1;
\n\t
"
"mov.b64 {v2, v3} , %2;
\n\t
"
"mov.b64 v01, {v0, v1};
\n\t
"
"mov.b64 v23, {v2, v3};
\n\t
"
"mul.f32x2 v01, v01, %3;
\n\t
"
// mind the shuffled elements order
"mul.f32x2 v23, v23, %3;
\n\t
"
// mind the shuffled elements order
"mov.b64 {v1, v0}, v01;
\n\t
"
"mov.b64 {v3, v2}, v23;
\n\t
"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;
\n\t
"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;
\n\t
"
"mov.b32 %0, {f0, f1, f0, f1};
\n\t
"
"}"
:
"=r"
(
out_4x
)
:
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in01
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in23
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
#else
NVTE_DEVICE_ERROR
(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX."
);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
)[
0
];
}
template
<
bool
USE_STOCHASTIC_ROUNDING
>
__device__
__forceinline__
fp4e2m1x4
mul_cvt_fp32_to_fp4_4x
(
const
float2
in01
,
const
float2
in23
,
const
float2
scale
,
const
uint32_t
rbits
)
{
if
constexpr
(
USE_STOCHASTIC_ROUNDING
)
{
return
mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding
(
in01
,
in23
,
scale
,
rbits
);
}
else
{
return
mul_cvt_fp32_to_fp4_4x_with_rn
(
in01
,
in23
,
scale
,
rbits
);
}
}
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
template
<
bool
COMPUTE_ACTIVATIONS
,
typename
ParamOP
,
float
(
*
OP
)(
float
,
const
ParamOP
&
),
typename
IType
,
bool
USE_STOCHASTIC_ROUNDING
,
bool
RETURN_TRANSPOSE
>
__global__
void
__launch_bounds__
(
THREADS_NUM
)
nvfp4_transpose_kernel
(
const
__grid_constant__
CUtensorMap
tensor_map_input
,
const
__grid_constant__
CUtensorMap
tensor_map_output
,
const
__grid_constant__
CUtensorMap
tensor_map_output_t
,
nvfp4_scale_t
*
const
scales_ptr
,
nvfp4_scale_t
*
const
scales_t_ptr
,
const
float
*
noop
,
const
float
*
const
amax_rowwise_ptr
,
const
float
*
const
amax_colwise_ptr
,
const
size_t
rows
,
const
size_t
cols
,
const
size_t
scale_stride
,
const
size_t
scale_stride_t
,
const
size_t
*
rng_state
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr
bool
NO_ACTIVATIONS_NOT_FP32_INPUT
=
(
!
COMPUTE_ACTIVATIONS
)
&&
(
!
std
::
is_same_v
<
IType
,
float
>
);
using
IType2
=
typename
ptx
::
FPx2
<
IType
>
;
if
constexpr
(
!
COMPUTE_ACTIVATIONS
)
{
if
(
noop
!=
nullptr
&&
noop
[
0
]
==
1.0
f
)
{
return
;
}
}
const
size_t
rng_sequence
=
threadIdx
.
x
+
blockIdx
.
x
*
THREADS_NUM
+
blockIdx
.
y
*
gridDim
.
x
*
THREADS_NUM
;
const
size_t
rng_seed
=
rng_state
!=
nullptr
?
rng_state
[
0
]
:
0
;
const
size_t
rng_offset
=
rng_state
!=
nullptr
?
rng_state
[
1
]
:
0
;
RNG
rng
(
rng_seed
,
rng_sequence
,
rng_offset
);
curanddx
::
uniform_bits
dist
;
uint4
random_uint4
=
USE_STOCHASTIC_ROUNDING
?
dist
.
generate4
(
rng
)
:
uint4
{
0
,
0
,
0
,
0
};
int
rnd_idx
=
0
;
// Index of the random number. It increments each time when used and resets to 0 if reaches 4x
constexpr
bool
IS_CACHED_ACT_OP
=
COMPUTE_ACTIVATIONS
;
const
size_t
block_offset_Y
=
blockIdx
.
y
*
CHUNK_DIM_Y
;
const
size_t
block_offset_X
=
blockIdx
.
x
*
CHUNK_DIM_X
;
const
size_t
block_offset_Y_t
=
blockIdx
.
x
*
CHUNK_DIM_X
;
const
size_t
block_offset_X_t
=
blockIdx
.
y
*
CHUNK_DIM_Y
;
const
size_t
chunk_rows
=
rows
-
block_offset_Y
;
const
size_t
scales_block_offset_Y_rowwise
=
blockIdx
.
y
*
CHUNK_DIM_Y
;
const
size_t
scales_block_offset_X_rowwise
=
blockIdx
.
x
*
SCALES_PER_CHUNK_X
;
const
size_t
scales_block_offset_Y_t
=
blockIdx
.
x
*
CHUNK_DIM_X
;
const
size_t
scales_block_offset_X_t
=
blockIdx
.
y
*
SCALES_PER_CHUNK_Y
;
const
size_t
tid_Y_rowwise
=
threadIdx
.
x
/
THREADS_X_ROWWISE
;
const
size_t
tid_X_rowwise
=
threadIdx
.
x
%
THREADS_X_ROWWISE
;
const
size_t
tid_X_colwise
=
threadIdx
.
x
;
const
size_t
tid_Y_t
=
tid_X_colwise
;
// const size_t tid_X_t = 0;
const
size_t
thread_offset_Y_rowwise
=
tid_Y_rowwise
;
const
size_t
thread_offset_X_rowwise
=
tid_X_rowwise
*
SCALE_DIM
;
const
size_t
thread_offset_X_colwise
=
tid_X_colwise
;
const
size_t
row_base_rowwise
=
block_offset_Y
+
thread_offset_Y_rowwise
;
const
size_t
row_base_colwise
=
block_offset_Y
;
const
size_t
col_base_colwise
=
block_offset_X
+
thread_offset_X_colwise
;
const
bool
col_out_of_bounds_colwise
=
(
col_base_colwise
>=
cols
);
const
size_t
scales_offset_Y_rowwise
=
scales_block_offset_Y_rowwise
+
tid_Y_rowwise
;
const
size_t
scales_offset_X_rowwise
=
scales_block_offset_X_rowwise
+
tid_X_rowwise
;
const
size_t
scales_offset_Y_t
=
scales_block_offset_Y_t
+
tid_Y_t
;
const
size_t
scales_offset_X_t
=
scales_block_offset_X_t
;
const
size_t
SFs_per_row
=
cols
/
SCALE_DIM
;
const
bool
rowwise_scale_is_within_bounds_X
=
scales_offset_X_rowwise
<
SFs_per_row
;
const
bool
colwise_scale_is_within_bounds_Y
=
scales_offset_Y_t
<
cols
;
// Helps resolving bank conflicts in shmem
const
int
thread_lane
=
threadIdx
.
x
%
THREADS_PER_WARP
;
const
int
bank_group
=
thread_lane
/
THREADS_PER_BANK
;
constexpr
size_t
buff_elems
=
BUFF_DIM_Y
*
BUFF_IN_DIM_X
;
constexpr
size_t
buff_elems_total
=
BUFFS_NUM
*
buff_elems
;
constexpr
size_t
buff_size_aligned_in
=
DIVUP_TO_MULTIPLE
(
buff_elems_total
*
sizeof
(
IType
),
TMA_SHMEM_ALIGNMENT
);
constexpr
size_t
buff_size_aligned_out
=
DIVUP_TO_MULTIPLE
((
buff_elems_total
*
4
)
/
8
,
TMA_SHMEM_ALIGNMENT
);
constexpr
size_t
in_mem
=
buff_size_aligned_in
;
constexpr
size_t
out_mem_rowwise_data
=
buff_size_aligned_out
;
constexpr
size_t
out_mem_colwise_data
=
buff_size_aligned_out
;
constexpr
size_t
out_mem_rowwise_scales
=
0
;
extern
__shared__
char
dynamic_shmem
[];
uintptr_t
base_shmem_ptr
=
reinterpret_cast
<
uintptr_t
>
(
dynamic_shmem
);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uintptr_t
dshmem
=
(
base_shmem_ptr
+
TMA_SHMEM_ALIGNMENT
-
1
)
&
~
(
static_cast
<
uintptr_t
>
(
TMA_SHMEM_ALIGNMENT
-
1
));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType
*
in_sh
=
reinterpret_cast
<
IType
*>
(
dshmem
);
fp4e2m1x2
*
out_data_sh
=
reinterpret_cast
<
fp4e2m1x2
*>
(
dshmem
+
in_mem
);
fp4e2m1x2
*
out_t_data_sh
=
reinterpret_cast
<
fp4e2m1x2
*>
(
dshmem
+
in_mem
+
out_mem_rowwise_data
);
nvfp4_scale_t
*
out_rowwise_scales_sh
=
reinterpret_cast
<
nvfp4_scale_t
*>
(
dshmem
+
in_mem
+
out_mem_rowwise_data
+
out_mem_colwise_data
);
nvfp4_scale_t
*
out_colwise_scales_sh
=
reinterpret_cast
<
nvfp4_scale_t
*>
(
dshmem
+
in_mem
+
out_mem_rowwise_data
+
out_mem_colwise_data
+
out_mem_rowwise_scales
);
IType
*
cached_act_sh
=
in_sh
;
// in_sh is used as a cache buffer
constexpr
size_t
shmem_buff_size
=
buff_size_aligned_in
/
BUFFS_NUM
;
const
bool
is_master_thread
=
(
threadIdx
.
x
==
0
);
// Compute a global encoding/decoding scaling factors for all S_dec_b
const
float
S_enc_rowwise
=
(
amax_rowwise_ptr
==
nullptr
)
?
1.0
f
:
compute_global_encode_scaling_factor_FP4
(
*
amax_rowwise_ptr
);
// NOTE: This is to match with how emulation code was written.
const
float
S_dec_rowwise
=
1.0
/
S_enc_rowwise
;
const
float
S_enc_colwise
=
(
amax_colwise_ptr
==
nullptr
)
?
S_enc_rowwise
:
compute_global_encode_scaling_factor_FP4
(
*
amax_colwise_ptr
);
const
float
S_dec_colwise
=
1.0
/
S_enc_colwise
;
float
thread_amax
=
0.0
f
;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__
alignas
(
8
)
uint64_t
mbar
[
STAGES
];
initialize_barriers
<
STAGES
,
THREADS_NUM
>
(
mbar
,
is_master_thread
);
copy_2d_to_shared
(
&
in_sh
[
0
],
&
tensor_map_input
,
block_offset_X
,
block_offset_Y
,
shmem_buff_size
,
&
mbar
[
0
],
is_master_thread
);
#pragma unroll
for
(
size_t
stage
=
0
;
stage
<
STAGES
;
++
stage
)
{
const
size_t
buff
=
stage
%
BUFFS_NUM
;
const
size_t
next_stage
=
stage
+
1
;
const
size_t
stage_offset_Y
=
stage
*
BUFF_DIM_Y
;
const
size_t
buff_offset_in
=
buff
*
BUFF_IN_SIZE
;
const
size_t
buff_offset_out
=
buff
*
BUFF_OUT_SIZE
;
const
size_t
buff_offset_out_t
=
buff
*
BUFF_OUT_T_SIZE
;
if
(
next_stage
<
STAGES
)
{
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx
::
cp_async_bulk_wait_group_read
<
1
>
();
const
size_t
next_buff
=
next_stage
%
BUFFS_NUM
;
const
size_t
next_stage_offset_Y
=
next_stage
*
BUFF_DIM_Y
;
const
size_t
global_offset_Y
=
block_offset_Y
+
next_stage_offset_Y
;
const
size_t
global_offset_X
=
block_offset_X
;
const
size_t
next_buff_offset
=
next_buff
*
BUFF_IN_SIZE
;
copy_2d_to_shared
(
&
in_sh
[
next_buff_offset
],
&
tensor_map_input
,
global_offset_X
,
global_offset_Y
,
shmem_buff_size
,
&
mbar
[
next_stage
],
is_master_thread
);
}
ptx
::
fence_proxy_async_shared_cta
();
// Wait for the data to have arrived
ptx
::
mbarrier_wait_parity
(
&
mbar
[
stage
],
0
);
float
block_amax
=
0.0
f
;
// COLWISE scaling
if
constexpr
(
RETURN_TRANSPOSE
)
{
#pragma unroll
for
(
size_t
it
=
0
;
it
<
ITERATIONS_TRANSPOSE
;
++
it
)
{
const
size_t
in_thread_offset_Y
=
0
+
it
*
SCALE_DIM
;
const
size_t
in_thread_offset_X
=
thread_offset_X_colwise
;
const
size_t
out_t_thread_offset_Y
=
thread_offset_X_colwise
;
const
size_t
out_t_thread_offset_X
=
0
+
it
*
BUFF_OUT_IT_OFFSET
;
const
size_t
shmem_offset_base_colwise_in
=
buff_offset_in
+
in_thread_offset_Y
*
BUFF_IN_DIM_X
+
in_thread_offset_X
;
const
size_t
shmem_offset_base_colwise_out_t
=
buff_offset_out_t
+
out_t_thread_offset_Y
*
BUFF_OUT_T_DIM_X
+
out_t_thread_offset_X
;
block_amax
=
0.0
f
;
float
in_compute_colwise
[
SCALE_DIM
];
IType
in_colwise_IType
[
SCALE_DIM
];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if
constexpr
(
NO_ACTIVATIONS_NOT_FP32_INPUT
)
{
IType
block_amax_f16
=
static_cast
<
IType
>
(
0.0
f
);
#pragma unroll
for
(
int
i
=
0
;
i
<
SCALE_DIM
;
++
i
)
{
const
int
shmem_offset_colwise
=
shmem_offset_base_colwise_in
+
i
*
BUFF_IN_DIM_X
;
in_colwise_IType
[
i
]
=
in_sh
[
shmem_offset_colwise
];
block_amax_f16
=
__hmax
(
block_amax_f16
,
__habs
(
in_colwise_IType
[
i
]));
}
block_amax
=
static_cast
<
float
>
(
block_amax_f16
);
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
SCALE_DIM
;
++
i
)
{
const
int
shmem_offset_colwise
=
shmem_offset_base_colwise_in
+
i
*
BUFF_IN_DIM_X
;
float
elt
=
static_cast
<
float
>
(
in_sh
[
shmem_offset_colwise
]);
if
constexpr
(
COMPUTE_ACTIVATIONS
)
{
elt
=
OP
(
elt
,
{});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if
constexpr
(
!
std
::
is_same_v
<
IType
,
float
>
)
{
elt
=
static_cast
<
float
>
(
static_cast
<
IType
>
(
elt
));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if
constexpr
(
IS_CACHED_ACT_OP
)
{
cached_act_sh
[
shmem_offset_colwise
]
=
static_cast
<
IType
>
(
elt
);
}
if
constexpr
(
COMPUTE_ACTIVATIONS
)
{
const
bool
row_out_of_bounds_colwise
=
(
row_base_colwise
+
stage_offset_Y
+
i
>=
rows
);
const
bool
out_of_bounds
=
(
col_out_of_bounds_colwise
||
row_out_of_bounds_colwise
);
if
(
!
out_of_bounds
)
{
block_amax
=
fmaxf
(
block_amax
,
fabsf
(
elt
));
}
}
else
{
// If no activation, elt is 0 so we can safely do this
block_amax
=
fmaxf
(
block_amax
,
fabsf
(
elt
));
}
in_compute_colwise
[
i
]
=
elt
;
}
}
// 2. Compute E4M3 scaling factor
const
nvfp4_scale_t
S_dec_b_fp8
=
compute_decoding_scaling_factor
(
block_amax
,
S_enc_colwise
);
// Store scaling factors through SHMEM
const
size_t
scale_idx_sh
=
tid_Y_t
*
SCALES_PER_CHUNK_Y
+
stage
*
ITERATIONS_TRANSPOSE
+
it
;
out_colwise_scales_sh
[
scale_idx_sh
]
=
S_dec_b_fp8
;
// Compute "correct" per-block encoding scaling factor
constexpr
float
float_max
=
detail
::
TypeExtrema
<
float
>::
max
;
const
float
block_scale_inverse
=
fminf
(
1.0
f
/
(
static_cast
<
float
>
(
S_dec_b_fp8
)
*
S_dec_colwise
),
float_max
);
// S_enc_b_fp8
const
float2
block_scale_inverse_2x
{
block_scale_inverse
,
block_scale_inverse
};
// 3. Scale elements
fp4e2m1x4
regs
[
SCALE_DIM
/
4
];
#pragma unroll
for
(
int
e
=
0
;
e
<
SCALE_DIM
/
4
;
++
e
)
{
const
uint32_t
rbits
=
get_rbits
(
rng
,
random_uint4
,
rnd_idx
);
if
constexpr
(
NO_ACTIVATIONS_NOT_FP32_INPUT
)
{
const
uint64_t
elts
=
*
reinterpret_cast
<
uint64_t
*>
(
&
in_colwise_IType
[
4
*
e
]);
regs
[
e
]
=
mul_cvt_bf16_to_fp4_4x
<
USE_STOCHASTIC_ROUNDING
>
(
elts
,
block_scale_inverse_2x
,
rbits
);
}
else
{
const
float2
in01
=
*
reinterpret_cast
<
float2
*>
(
&
in_compute_colwise
[
4
*
e
]);
const
float2
in23
=
*
reinterpret_cast
<
float2
*>
(
&
in_compute_colwise
[
4
*
e
+
2
]);
regs
[
e
]
=
mul_cvt_fp32_to_fp4_4x
<
USE_STOCHASTIC_ROUNDING
>
(
in01
,
in23
,
block_scale_inverse_2x
,
rbits
);
}
}
const
int
group
=
thread_lane
/
16
;
uint32_t
val
[
2
];
uint32_t
*
regs_4x
=
reinterpret_cast
<
uint32_t
*>
(
regs
);
// Helps reducing bank conflicts
switch
(
group
)
{
case
0
:
val
[
0
]
=
regs_4x
[
0
];
val
[
1
]
=
regs_4x
[
1
];
break
;
case
1
:
val
[
0
]
=
regs_4x
[
1
];
val
[
1
]
=
regs_4x
[
0
];
break
;
}
uint32_t
*
out_t_data_sh_as_uint32_t
=
reinterpret_cast
<
uint32_t
*>
(
&
out_t_data_sh
[
shmem_offset_base_colwise_out_t
]);
out_t_data_sh_as_uint32_t
[
group
]
=
val
[
0
];
// idx1 = (group + 0) % 2;
out_t_data_sh_as_uint32_t
[(
group
+
1
)
&
1
]
=
val
[
1
];
// idx2 = (group + 1) % 2;
}
}
// ROWWISE scaling
{
const
size_t
stage_rowwise_scales_offset_Y
=
stage
*
BUFF_DIM_Y
;
#pragma unroll
for
(
size_t
it
=
0
;
it
<
ITERATIONS_NORMAL
;
++
it
)
{
const
size_t
it_thread_offset_Y_rowwise
=
thread_offset_Y_rowwise
+
it
*
THREADS_Y_ROWWISE
;
const
size_t
shmem_offset_base_rowwise_in
=
buff_offset_in
+
it_thread_offset_Y_rowwise
*
BUFF_IN_DIM_X
;
const
size_t
shmem_offset_base_rowwise_out
=
buff_offset_out
+
it_thread_offset_Y_rowwise
*
BUFF_OUT_DIM_X
;
const
size_t
it_offset_Y
=
stage_offset_Y
+
it
*
THREADS_Y_ROWWISE
;
block_amax
=
0.0
f
;
float
in_compute_rowwise
[
SCALE_DIM
];
Vec
<
IType
,
PACK_SIZE
>
in_cached
[
WAVES
];
// used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
Vec
<
IType2
,
PACK_SIZE
/
2
>
in_IType
[
WAVES
];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if
constexpr
(
NO_ACTIVATIONS_NOT_FP32_INPUT
)
{
IType2
thread_amax_2x
=
{
static_cast
<
IType
>
(
0.0
f
),
static_cast
<
IType
>
(
0.0
f
)};
#pragma unroll
for
(
int
w
=
0
;
w
<
WAVES
;
++
w
)
{
const
size_t
swizzled_group_idx
=
((
w
+
bank_group
)
*
PACK_SIZE
)
%
SCALE_DIM
;
const
size_t
swizzled_thread_idx
=
thread_offset_X_rowwise
+
swizzled_group_idx
;
const
size_t
shmem_offset_rowwise
=
shmem_offset_base_rowwise_in
+
swizzled_thread_idx
;
// Load elements
in_IType
[
w
].
load_from
(
&
in_sh
[
shmem_offset_rowwise
]);
#pragma unroll
for
(
int
e
=
0
;
e
<
PACK_SIZE
/
2
;
++
e
)
{
ptx
::
abs_max_2x
(
thread_amax_2x
,
thread_amax_2x
,
in_IType
[
w
].
data
.
elt
[
e
]);
}
}
block_amax
=
static_cast
<
float
>
(
__hmax
(
__habs
(
thread_amax_2x
.
x
),
__habs
(
thread_amax_2x
.
y
)));
}
else
if
constexpr
(
IS_CACHED_ACT_OP
)
{
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads
();
IType2
thread_amax_2x
=
{
static_cast
<
IType
>
(
0.0
f
),
static_cast
<
IType
>
(
0.0
f
)};
#pragma unroll
for
(
int
w
=
0
;
w
<
WAVES
;
++
w
)
{
const
size_t
swizzled_group_idx
=
((
w
+
bank_group
)
*
PACK_SIZE
)
%
SCALE_DIM
;
const
size_t
swizzled_thread_idx
=
thread_offset_X_rowwise
+
swizzled_group_idx
;
const
size_t
shmem_offset_rowwise
=
shmem_offset_base_rowwise_in
+
swizzled_thread_idx
;
const
bool
row_out_of_bounds_rowwise
=
(
row_base_rowwise
+
it_offset_Y
>=
rows
);
const
bool
swizzled_col_out_of_bounds
=
(
block_offset_X
+
swizzled_thread_idx
>=
cols
);
const
bool
out_of_bounds
=
(
row_out_of_bounds_rowwise
||
swizzled_col_out_of_bounds
);
// Load cached elements
in_cached
[
w
].
load_from
(
&
cached_act_sh
[
shmem_offset_rowwise
]);
// Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements)
// only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries
if
(
!
out_of_bounds
)
{
if
constexpr
(
std
::
is_same_v
<
IType
,
float
>
)
{
#pragma unroll
for
(
int
e
=
0
;
e
<
PACK_SIZE
;
++
e
)
{
block_amax
=
fmaxf
(
block_amax
,
fabsf
(
in_cached
[
w
].
data
.
elt
[
e
]));
}
}
else
{
#pragma unroll
for
(
int
e
=
0
;
e
<
PACK_SIZE
;
e
+=
2
)
{
const
IType2
in_cached_2x
=
{
in_cached
[
w
].
data
.
elt
[
e
],
in_cached
[
w
].
data
.
elt
[
e
+
1
]};
ptx
::
abs_max_2x
(
thread_amax_2x
,
thread_amax_2x
,
in_cached_2x
);
}
}
}
}
if
constexpr
(
!
std
::
is_same_v
<
IType
,
float
>
)
{
block_amax
=
static_cast
<
float
>
(
__hmax
(
__habs
(
thread_amax_2x
.
x
),
__habs
(
thread_amax_2x
.
y
)));
}
}
else
{
#pragma unroll
for
(
int
w
=
0
;
w
<
WAVES
;
++
w
)
{
const
size_t
swizzled_group_idx
=
((
w
+
bank_group
)
*
PACK_SIZE
)
%
SCALE_DIM
;
const
size_t
swizzled_thread_idx
=
thread_offset_X_rowwise
+
swizzled_group_idx
;
const
size_t
shmem_offset_rowwise
=
shmem_offset_base_rowwise_in
+
swizzled_thread_idx
;
Vec
<
IType
,
PACK_SIZE
>
in
;
Vec
<
IType
,
PACK_SIZE
>
act_in
;
in
.
load_from
(
&
in_sh
[
shmem_offset_rowwise
]);
#pragma unroll
for
(
int
e
=
0
;
e
<
PACK_SIZE
;
++
e
)
{
const
size_t
j
=
w
*
PACK_SIZE
+
e
;
// Compute element
float
elt
=
static_cast
<
float
>
(
in
.
data
.
elt
[
e
]);
if
constexpr
(
COMPUTE_ACTIVATIONS
)
{
elt
=
OP
(
elt
,
{});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if
constexpr
(
!
std
::
is_same_v
<
IType
,
float
>
)
{
elt
=
static_cast
<
float
>
(
static_cast
<
IType
>
(
elt
));
}
if
constexpr
(
COMPUTE_ACTIVATIONS
)
{
const
bool
row_out_of_bounds_rowwise
=
(
row_base_rowwise
+
it_offset_Y
>=
rows
);
const
bool
swizzled_col_out_of_bounds
=
(
block_offset_X
+
swizzled_thread_idx
>=
cols
);
const
bool
out_of_bounds
=
(
row_out_of_bounds_rowwise
||
swizzled_col_out_of_bounds
);
if
(
!
out_of_bounds
)
{
block_amax
=
fmaxf
(
block_amax
,
fabsf
(
elt
));
}
}
else
{
// If no activation, elt is 0 so we can safely do this
block_amax
=
fmaxf
(
block_amax
,
fabsf
(
elt
));
}
in_compute_rowwise
[
j
]
=
elt
;
}
}
}
// 2. Compute E4M3 scaling factor
const
nvfp4_scale_t
S_dec_b_fp8
=
compute_decoding_scaling_factor
(
block_amax
,
S_enc_rowwise
);
// Check boundaries
const
size_t
scales_offset_Y
=
scales_offset_Y_rowwise
+
stage
*
BUFF_DIM_Y
+
it
*
THREADS_Y_ROWWISE
;
const
size_t
scales_offset_X
=
scales_offset_X_rowwise
;
const
size_t
scale_idx_global
=
scales_offset_Y
*
scale_stride
+
scales_offset_X
;
// const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows;
const
bool
rowwise_scale_is_within_bounds_Y
=
(
stage_rowwise_scales_offset_Y
+
it
*
THREADS_Y_ROWWISE
+
tid_Y_rowwise
)
<
chunk_rows
;
if
(
rowwise_scale_is_within_bounds_X
&&
rowwise_scale_is_within_bounds_Y
)
{
scales_ptr
[
scale_idx_global
]
=
S_dec_b_fp8
;
}
// Compute "correct" per-block encoding scaling factor
constexpr
float
float_max
=
detail
::
TypeExtrema
<
float
>::
max
;
const
float
block_scale_inverse
=
fminf
(
1.0
f
/
(
static_cast
<
float
>
(
S_dec_b_fp8
)
*
S_dec_rowwise
),
float_max
);
// S_enc_b_fp8
const
float2
block_scale_inverse_2x
{
block_scale_inverse
,
block_scale_inverse
};
// 3. Scale elements
#pragma unroll
for
(
int
w
=
0
;
w
<
WAVES
;
++
w
)
{
Vec
<
fp4e2m1x4
,
PACK_SIZE
/
4
>
out
;
#pragma unroll
for
(
int
e
=
0
;
e
<
PACK_SIZE
/
4
;
++
e
)
{
const
uint32_t
rbits
=
get_rbits
(
rng
,
random_uint4
,
rnd_idx
);
IType2
in01
;
IType2
in23
;
if
constexpr
(
NO_ACTIVATIONS_NOT_FP32_INPUT
)
{
const
uint64_t
elts
=
*
reinterpret_cast
<
uint64_t
*>
(
&
in_IType
[
w
].
data
.
elt
[
2
*
e
]);
out
.
data
.
elt
[
e
]
=
mul_cvt_bf16_to_fp4_4x
<
USE_STOCHASTIC_ROUNDING
>
(
elts
,
block_scale_inverse_2x
,
rbits
);
}
else
if
constexpr
(
IS_CACHED_ACT_OP
)
{
const
uint64_t
elts
=
*
reinterpret_cast
<
uint64_t
*>
(
&
in_cached
[
w
].
data
.
elt
[
4
*
e
]);
out
.
data
.
elt
[
e
]
=
mul_cvt_bf16_to_fp4_4x
<
USE_STOCHASTIC_ROUNDING
>
(
elts
,
block_scale_inverse_2x
,
rbits
);
}
else
{
const
int
j
=
w
*
PACK_SIZE
+
4
*
e
;
const
float2
in01
=
make_float2
(
in_compute_rowwise
[
j
],
in_compute_rowwise
[
j
+
1
]);
const
float2
in23
=
make_float2
(
in_compute_rowwise
[
j
+
2
],
in_compute_rowwise
[
j
+
3
]);
out
.
data
.
elt
[
e
]
=
mul_cvt_fp32_to_fp4_4x
<
USE_STOCHASTIC_ROUNDING
>
(
in01
,
in23
,
block_scale_inverse_2x
,
rbits
);
}
}
const
size_t
swizzled_group_idx
=
((
w
+
bank_group
)
*
PACK_SIZE
)
%
SCALE_DIM
;
const
size_t
swizzled_idx
=
swizzled_group_idx
+
thread_offset_X_rowwise
;
const
size_t
shmem_offset_rowwise
=
shmem_offset_base_rowwise_out
+
swizzled_idx
/
2
;
out
.
store_to
(
&
out_data_sh
[
shmem_offset_rowwise
]);
}
}
}
__builtin_assume
(
thread_amax
>=
0
);
thread_amax
=
fmaxf
(
thread_amax
,
block_amax
);
// Wait for shared memory writes to be visible to TMA engine.
ptx
::
fence_proxy_async_shared_cta
();
__syncthreads
();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if
(
is_master_thread
)
{
const
size_t
global_offset_Y
=
block_offset_Y
+
stage_offset_Y
;
const
size_t
global_offset_X
=
block_offset_X
;
const
size_t
global_offset_Y_t
=
block_offset_Y_t
;
const
size_t
global_offset_X_t
=
block_offset_X_t
+
stage_offset_Y
;
ptx
::
cp_async_bulk_tensor_2d_shared_to_global
(
reinterpret_cast
<
const
uint64_t
*>
(
&
tensor_map_output
),
global_offset_X
,
global_offset_Y
,
reinterpret_cast
<
uint64_t
*>
(
&
out_data_sh
[
buff_offset_out
]));
if
constexpr
(
RETURN_TRANSPOSE
)
{
ptx
::
cp_async_bulk_tensor_2d_shared_to_global
(
reinterpret_cast
<
const
uint64_t
*>
(
&
tensor_map_output_t
),
global_offset_X_t
,
global_offset_Y_t
,
reinterpret_cast
<
uint64_t
*>
(
&
out_t_data_sh
[
buff_offset_out_t
]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx
::
cp_async_bulk_commit_group
();
}
}
// end of stages
// Vectorized store scaling factors through SHMEM
if
(
RETURN_TRANSPOSE
&&
colwise_scale_is_within_bounds_Y
)
{
using
ScalesVec
=
Vec
<
nvfp4_scale_t
,
SCALES_PER_CHUNK_Y
>
;
const
size_t
scale_idx_sh
=
tid_Y_t
*
SCALES_PER_CHUNK_Y
;
ScalesVec
&
scales_vec
=
*
reinterpret_cast
<
ScalesVec
*>
(
&
out_colwise_scales_sh
[
scale_idx_sh
]);
const
size_t
scale_idx_global
=
scales_offset_Y_t
*
scale_stride_t
+
scales_offset_X_t
;
const
size_t
count
=
// number of scales in Y dimension of this chunk
(
chunk_rows
>=
CHUNK_DIM_Y
)
?
SCALES_PER_CHUNK_Y
:
(
chunk_rows
/
SCALE_DIM
);
nvfp4_scale_t
*
dst
=
&
scales_t_ptr
[
scale_idx_global
];
constexpr
size_t
vec_bytes
=
SCALES_PER_CHUNK_Y
*
sizeof
(
nvfp4_scale_t
);
if
(
count
==
SCALES_PER_CHUNK_Y
&&
(
reinterpret_cast
<
uintptr_t
>
(
dst
)
%
vec_bytes
==
0
))
{
// Fast path: vectorized store when destination is properly aligned
scales_vec
.
store_to
(
dst
);
}
else
{
// Safe path: element-wise store for tails or unaligned destinations
scales_vec
.
store_to_elts
(
dst
,
0
,
count
);
}
}
destroy_barriers
<
STAGES
>
(
mbar
,
is_master_thread
);
#else
NVTE_DEVICE_ERROR
(
"sm_100 or higher is required."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
template
<
bool
COMPUTE_ACTIVATIONS
,
typename
ParamOP
,
float
(
*
OP
)(
float
,
const
ParamOP
&
),
typename
IType
,
bool
USE_STOCHASTIC_ROUNDING
,
bool
RETURN_TRANSPOSE
>
__global__
void
__launch_bounds__
(
THREADS_NUM
)
nvfp4_transpose_kernel_2D
(
const
__grid_constant__
CUtensorMap
tensor_map_input
,
const
__grid_constant__
CUtensorMap
tensor_map_output
,
const
__grid_constant__
CUtensorMap
tensor_map_output_t
,
nvfp4_scale_t
*
const
scales_ptr
,
nvfp4_scale_t
*
const
scales_t_ptr
,
const
float
*
noop
,
const
float
*
const
amax_rowwise_ptr
,
const
float
*
const
amax_colwise_ptr
,
const
size_t
rows
,
const
size_t
cols
,
const
size_t
scale_stride
,
const
size_t
scale_stride_t
,
const
size_t
*
rng_state
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr
bool
NO_ACTIVATIONS_NOT_FP32_INPUT
=
(
!
COMPUTE_ACTIVATIONS
)
&&
(
!
std
::
is_same_v
<
IType
,
float
>
);
using
IType2
=
typename
ptx
::
FPx2
<
IType
>
;
if
constexpr
(
!
COMPUTE_ACTIVATIONS
)
{
if
(
noop
!=
nullptr
&&
noop
[
0
]
==
1.0
f
)
{
return
;
}
}
const
size_t
rng_sequence
=
threadIdx
.
x
+
blockIdx
.
x
*
THREADS_NUM
+
blockIdx
.
y
*
gridDim
.
x
*
THREADS_NUM
;
const
size_t
rng_seed
=
rng_state
!=
nullptr
?
rng_state
[
0
]
:
0
;
const
size_t
rng_offset
=
rng_state
!=
nullptr
?
rng_state
[
1
]
:
0
;
RNG
rng
(
rng_seed
,
rng_sequence
,
rng_offset
);
curanddx
::
uniform_bits
dist
;
uint4
random_uint4
=
USE_STOCHASTIC_ROUNDING
?
dist
.
generate4
(
rng
)
:
uint4
{
0
,
0
,
0
,
0
};
int
rnd_idx
=
0
;
// Index of the random number. It increments each time when used and resets to 0 if reaches 4x
// NEW: 2D Block-based scaling constants
constexpr
size_t
BLOCK_DIM
=
16
;
constexpr
size_t
BLOCKS_PER_TILE_Y
=
TILE_DIM_Y
/
BLOCK_DIM
;
// 32/16 = 2
constexpr
size_t
BLOCKS_PER_TILE_X
=
TILE_DIM_X
/
BLOCK_DIM
;
// 128/16 = 8
constexpr
size_t
ITERATIONS_BLOCK
=
2
;
// iterations to calculate 2d block amaxes of 1 tile
constexpr
size_t
BLOCKS_PER_WARP
=
BLOCKS_PER_TILE_X
/
(
THREADS_NUM
/
32
);
// 8 / (128/32) = 2
constexpr
bool
IS_CACHED_ACT_OP
=
COMPUTE_ACTIVATIONS
;
const
size_t
block_offset_Y
=
blockIdx
.
y
*
CHUNK_DIM_Y
;
const
size_t
block_offset_X
=
blockIdx
.
x
*
CHUNK_DIM_X
;
const
size_t
block_offset_Y_t
=
blockIdx
.
x
*
CHUNK_DIM_X
;
const
size_t
block_offset_X_t
=
blockIdx
.
y
*
CHUNK_DIM_Y
;
const
size_t
chunk_rows
=
rows
-
block_offset_Y
;
const
size_t
scales_block_offset_Y_rowwise
=
blockIdx
.
y
*
CHUNK_DIM_Y
;
const
size_t
scales_block_offset_X_rowwise
=
blockIdx
.
x
*
SCALES_PER_CHUNK_X
;
const
size_t
scales_block_offset_Y_t
=
blockIdx
.
x
*
CHUNK_DIM_X
;
const
size_t
scales_block_offset_X_t
=
blockIdx
.
y
*
SCALES_PER_CHUNK_Y
;
const
size_t
tid_Y_rowwise
=
threadIdx
.
x
/
THREADS_X_ROWWISE
;
const
size_t
tid_X_rowwise
=
threadIdx
.
x
%
THREADS_X_ROWWISE
;
const
size_t
tid_X_colwise
=
threadIdx
.
x
;
const
size_t
tid_Y_t
=
tid_X_colwise
;
const
size_t
thread_offset_Y_rowwise
=
tid_Y_rowwise
;
const
size_t
thread_offset_X_rowwise
=
tid_X_rowwise
*
SCALE_DIM
;
const
size_t
thread_offset_X_colwise
=
tid_X_colwise
;
const
size_t
scales_offset_Y_rowwise
=
scales_block_offset_Y_rowwise
+
tid_Y_rowwise
;
const
size_t
scales_offset_X_rowwise
=
scales_block_offset_X_rowwise
+
tid_X_rowwise
;
const
size_t
scales_offset_Y_t
=
scales_block_offset_Y_t
+
tid_Y_t
;
const
size_t
scales_offset_X_t
=
scales_block_offset_X_t
;
const
size_t
SFs_per_row
=
cols
/
SCALE_DIM
;
const
bool
rowwise_scale_is_within_bounds_X
=
scales_offset_X_rowwise
<
SFs_per_row
;
const
bool
colwise_scale_is_within_bounds_Y
=
scales_offset_Y_t
<
cols
;
// Helps resolving bank conflicts in shmem
const
int
thread_lane
=
threadIdx
.
x
%
THREADS_PER_WARP
;
const
int
bank_group
=
thread_lane
/
THREADS_PER_BANK
;
  constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
  constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
  constexpr size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out =
      DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);

  constexpr size_t in_mem = buff_size_aligned_in;

  constexpr size_t out_mem_rowwise_data = buff_size_aligned_out;
  constexpr size_t out_mem_colwise_data = buff_size_aligned_out;
  constexpr size_t out_mem_rowwise_scales = 0;

  extern __shared__ char dynamic_shmem[];
  uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) Does not guarantee the pointer to be aligned!
  uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
                     ~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));

  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
  IType *in_sh = reinterpret_cast<IType *>(dshmem);
  fp4e2m1x2 *out_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
  fp4e2m1x2 *out_t_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem + out_mem_rowwise_data);
  nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
      dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
  nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
      dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
  IType *cached_act_sh = in_sh;  // in_sh is used as a cache buffer

  constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;

  const bool is_master_thread = (threadIdx.x == 0);

  // Compute a global encoding/decoding scaling factors for all S_dec_b
  const float S_enc_rowwise = (amax_rowwise_ptr == nullptr)
                                  ? 1.0f
                                  : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
  // NOTE: This is to match with how emulation code was written.
  const float S_dec_rowwise = 1.0 / S_enc_rowwise;

  const float S_enc_colwise = (amax_colwise_ptr == nullptr)
                                  ? S_enc_rowwise
                                  : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
  const float S_dec_colwise = 1.0 / S_enc_colwise;
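  // Two-level NVFP4 scaling, for reference: a per-tensor FP32 encode scale S_enc is derived from
  // the global amax, and every 16-element block additionally gets an FP8-E4M3 decode scale
  // S_dec_b. The encode factor actually applied to the data is 1 / (S_dec_b * S_dec), computed
  // per block below. (compute_global_encode_scaling_factor_FP4 is defined elsewhere in this file;
  // a common choice is FP8_E4M3_MAX * FP4_E2M1_MAX / amax, but that exact formula is an
  // assumption here, not shown in this hunk.)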
  const size_t warp_id = threadIdx.x / 32;
  const size_t lane_id = threadIdx.x % 32;

  float thread_amax = 0.0f;
  const size_t block_in_warp = lane_id / BLOCKS_PER_WARP;

  // Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ alignas(8) uint64_t mbar[STAGES];
  __shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1];

  // Helper function for warp reduction
  auto warp_reduce_amax = [](float thread_amax, int block_in_warp) -> float {
#pragma unroll
    for (int delta = 8; delta >= 1; delta /= 2) {
      float other_amax = __shfl_xor_sync(0xffffffff, thread_amax, delta);
      thread_amax = fmaxf(thread_amax, other_amax);
    }
    return thread_amax;
  };
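  // The XOR butterfly above reduces over strides 8, 4, 2, 1, i.e. within independent groups of
  // 16 lanes. Each warp therefore produces two block amaxes (lanes 0-15 and lanes 16-31), which
  // is why block_amax_matrix is written only from lanes 0 and 16 in the stage loop below.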
  initialize_barriers<STAGES, THREADS_NUM>(mbar, is_master_thread);

  copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
                    &mbar[0], is_master_thread);
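  // Software pipeline: the first input buffer is prefetched above; inside the stage loop the
  // next buffer is issued via TMA while the current one is being quantized, cycling through
  // BUFFS_NUM shared-memory buffers with one mbarrier per stage.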
#pragma unroll
  for (size_t stage = 0; stage < STAGES; ++stage) {
    const size_t buff = stage % BUFFS_NUM;
    const size_t next_stage = stage + 1;
    const size_t stage_offset_Y = stage * BUFF_DIM_Y;

    const size_t buff_offset_in = buff * BUFF_IN_SIZE;
    const size_t buff_offset_out = buff * BUFF_OUT_SIZE;
    const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE;

    if (next_stage < STAGES) {
      // Wait for TMA transfer to have finished reading shared memory.
      // I.e. the buffer is ready to be written to
      ptx::cp_async_bulk_wait_group_read<1>();

      const size_t next_buff = next_stage % BUFFS_NUM;
      const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
      const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
      const size_t global_offset_X = block_offset_X;
      const size_t next_buff_offset = next_buff * BUFF_IN_SIZE;

      copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
                        global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
    }

    ptx::fence_proxy_async_shared_cta();

    // Wait for the data to have arrived
    ptx::mbarrier_wait_parity(&mbar[stage], 0);

    float block_amax = 0.0f;
#pragma unroll
    for (size_t block_iter = 0; block_iter < ITERATIONS_BLOCK; ++block_iter) {
      IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
      const size_t block_in_tile_y = block_iter;
      const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM;
      if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
        for (int elem = 0; elem < BLOCK_DIM; elem += 2) {
          const size_t elem_0_row = block_iter * BLOCK_DIM + elem;
          const size_t elem_1_row = elem_0_row + 1;
          const size_t elem_0_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id;
          const size_t elem_1_col = elem_0_col;

          const size_t shmem_offset_0 = buff_offset_in + elem_0_row * BUFF_IN_DIM_X + elem_0_col;
          const size_t shmem_offset_1 = buff_offset_in + elem_1_row * BUFF_IN_DIM_X + elem_1_col;

          IType2 val_2x;
          val_2x.x = in_sh[shmem_offset_0];
          val_2x.y = in_sh[shmem_offset_1];
          ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val_2x);
        }
        thread_amax =
            static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
      } else {
        for (int elem = 0; elem < BLOCK_DIM; ++elem) {
          const size_t elem_row = block_iter * BLOCK_DIM + elem;
          const size_t elem_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id;
          // Bounds checking
          const bool row_out_of_bounds = (block_offset_Y + stage_offset_Y + elem_row >= rows);
          const bool col_out_of_bounds = (block_offset_X + elem_col >= cols);
          if (!row_out_of_bounds && !col_out_of_bounds) {
            const size_t shmem_offset = buff_offset_in + elem_row * BUFF_IN_DIM_X + elem_col;
            float elt = static_cast<float>(in_sh[shmem_offset]);
            if constexpr (COMPUTE_ACTIVATIONS) {
              elt = OP(elt, {});
            }
            if constexpr (!std::is_same_v<IType, float>) {
              elt = static_cast<float>(static_cast<IType>(elt));
            }
            // Cache computed activations
            if constexpr (IS_CACHED_ACT_OP) {
              cached_act_sh[shmem_offset] = static_cast<IType>(elt);
            }
            thread_amax = fmaxf(thread_amax, fabsf(elt));
          }
        }
      }
      // Warp reduction to get block amax
      block_amax = warp_reduce_amax(thread_amax, block_in_warp);
      if (lane_id == 0 || lane_id == 16) {
        block_amax_matrix[block_in_tile_y][block_in_tile_x] = block_amax;
      }
    }

    // sync thread to ensure block_amax_matrix is done storing
    __syncthreads();
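    // block_amax_matrix now holds one amax per 2D block of the tile; both the columnwise
    // (transposed) and the rowwise passes below reuse these values instead of recomputing
    // per-direction amaxes.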
    // COLWISE scaling
    if constexpr (RETURN_TRANSPOSE) {
#pragma unroll
      for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) {
        const size_t block_in_tile_y = it;
        const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM;

        const size_t in_thread_offset_Y = 0 + it * SCALE_DIM;
        const size_t in_thread_offset_X = thread_offset_X_colwise;

        const size_t out_t_thread_offset_Y = thread_offset_X_colwise;
        const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET;

        const size_t shmem_offset_base_colwise_in =
            buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X;
        const size_t shmem_offset_base_colwise_out_t =
            buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X;

        block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x];

        float in_compute_colwise[SCALE_DIM];
        IType in_colwise_IType[SCALE_DIM];

        // 3. Scale elements
        // Load data in
        if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
#pragma unroll
          for (int i = 0; i < SCALE_DIM; ++i) {
            const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
            in_colwise_IType[i] = in_sh[shmem_offset_colwise];
          }
        } else {
          for (int i = 0; i < SCALE_DIM; ++i) {
            const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
            float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
            if constexpr (COMPUTE_ACTIVATIONS) {
              elt = OP(elt, {});
            }
            // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
            if constexpr (!std::is_same_v<IType, float>) {
              elt = static_cast<float>(static_cast<IType>(elt));
            }
            // Cache computed activations to avoid computing them again
            // in the 2nd pass along another dimension
            if constexpr (IS_CACHED_ACT_OP) {
              cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
            }
            in_compute_colwise[i] = elt;
          }
        }

        // 2. Compute E4M3 scaling factor
        const nvfp4_scale_t S_dec_b_fp8 =
            compute_decoding_scaling_factor(block_amax, S_enc_colwise);

        // Store scaling factors through SHMEM
        const size_t scale_idx_sh =
            tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it;
        out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8;

        // Compute "correct" per-block encoding scaling factor
        constexpr float float_max = detail::TypeExtrema<float>::max;
        const float block_scale_inverse = fminf(
            1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_colwise), float_max);  // S_enc_b_fp8
        const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};

        fp4e2m1x4 regs[SCALE_DIM / 4];
#pragma unroll
        for (int e = 0; e < SCALE_DIM / 4; ++e) {
          const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
          if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
            const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_colwise_IType[4 * e]);
            regs[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(elts, block_scale_inverse_2x,
                                                                      rbits);
          } else {
            const float2 in01 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e]);
            const float2 in23 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e + 2]);
            regs[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
                in01, in23, block_scale_inverse_2x, rbits);
          }
        }

        const int group = thread_lane / 16;
        uint32_t val[2];
        uint32_t *regs_4x = reinterpret_cast<uint32_t *>(regs);
        // Helps reducing bank conflicts
        switch (group) {
          case 0:
            val[0] = regs_4x[0];
            val[1] = regs_4x[1];
            break;
          case 1:
            val[0] = regs_4x[1];
            val[1] = regs_4x[0];
            break;
        }
        uint32_t *out_t_data_sh_as_uint32_t =
            reinterpret_cast<uint32_t *>(&out_t_data_sh[shmem_offset_base_colwise_out_t]);
        out_t_data_sh_as_uint32_t[group] = val[0];            // idx1 = (group + 0) % 2;
        out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1];  // idx2 = (group + 1) % 2;
      }
    }
    // ROWWISE scaling
    {
      const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
      for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) {
        const size_t block_in_tile_y = it;
        const size_t block_in_tile_x = tid_X_rowwise;

        const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;

        const size_t shmem_offset_base_rowwise_in =
            buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
        const size_t shmem_offset_base_rowwise_out =
            buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;

        block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x];

        float in_compute_rowwise[SCALE_DIM];
        Vec<IType, PACK_SIZE> in_cached[WAVES];

        // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
        Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];

        // 1. Read/Compute elements. Find NVFP4-block AMAX
        if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
          IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
            const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
            // Load elements
            in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
          }
        } else if constexpr (IS_CACHED_ACT_OP) {
          // ensures that all writes to cache made in the section above are visible to all threads
          __syncthreads();
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
            const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
            // Load cached elements
            in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
          }
        } else {
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
            const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
            Vec<IType, PACK_SIZE> in;
            Vec<IType, PACK_SIZE> act_in;
            in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
            for (int e = 0; e < PACK_SIZE; ++e) {
              const size_t j = w * PACK_SIZE + e;
              // Compute element
              float elt = static_cast<float>(in.data.elt[e]);
              if constexpr (COMPUTE_ACTIVATIONS) {
                elt = OP(elt, {});
              }
              // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
              if constexpr (!std::is_same_v<IType, float>) {
                elt = static_cast<float>(static_cast<IType>(elt));
              }
              in_compute_rowwise[j] = elt;
            }
          }
        }

        // 2. Compute E4M3 scaling factor
        const nvfp4_scale_t S_dec_b_fp8 =
            compute_decoding_scaling_factor(block_amax, S_enc_rowwise);

        // Check boundaries
        const size_t scales_offset_Y =
            scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE;
        const size_t scales_offset_X = scales_offset_X_rowwise;
        const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X;
        // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows;
        const bool rowwise_scale_is_within_bounds_Y =
            (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows;
        if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) {
          scales_ptr[scale_idx_global] = S_dec_b_fp8;
        }

        // Compute "correct" per-block encoding scaling factor
        constexpr float float_max = detail::TypeExtrema<float>::max;
        const float block_scale_inverse = fminf(
            1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_rowwise), float_max);  // S_enc_b_fp8
        const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};

        // 3. Scale elements
#pragma unroll
        for (int w = 0; w < WAVES; ++w) {
          Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
          for (int e = 0; e < PACK_SIZE / 4; ++e) {
            const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
            IType2 in01;
            IType2 in23;
            if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
              const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_IType[w].data.elt[2 * e]);
              out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
                  elts, block_scale_inverse_2x, rbits);
            } else if constexpr (IS_CACHED_ACT_OP) {
              const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_cached[w].data.elt[4 * e]);
              out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
                  elts, block_scale_inverse_2x, rbits);
            } else {
              const int j = w * PACK_SIZE + 4 * e;
              const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]);
              const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]);
              out.data.elt[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
                  in01, in23, block_scale_inverse_2x, rbits);
            }
          }
          const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
          const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
          const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
          out.store_to(&out_data_sh[shmem_offset_rowwise]);
        }
      }
    }
    __builtin_assume(thread_amax >= 0);
    thread_amax = fmaxf(thread_amax, block_amax);

    // Wait for shared memory writes to be visible to TMA engine.
    ptx::fence_proxy_async_shared_cta();
    __syncthreads();
    // After syncthreads, writes by all threads are visible to TMA engine.

    // Initiate TMA transfer to copy shared memory to global memory
    if (is_master_thread) {
      const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
      const size_t global_offset_X = block_offset_X;
      const size_t global_offset_Y_t = block_offset_Y_t;
      const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y;

      ptx::cp_async_bulk_tensor_2d_shared_to_global(
          reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y,
          reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));

      if constexpr (RETURN_TRANSPOSE) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_t), global_offset_X_t,
            global_offset_Y_t, reinterpret_cast<uint64_t *>(&out_t_data_sh[buff_offset_out_t]));
      }

      // Create a "bulk async-group" out of the previous bulk copy operation.
      ptx::cp_async_bulk_commit_group();
    }
  }  // end of stages

  // Vectorized store scaling factors through SHMEM
  if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) {
    using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_Y>;
    const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y;
    ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(&out_colwise_scales_sh[scale_idx_sh]);
    const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t;
    // number of scales in Y dimension of this chunk
    const size_t count =
        (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM);
    nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global];
    constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t);
    if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast<uintptr_t>(dst) % vec_bytes == 0)) {
      // Fast path: vectorized store when destination is properly aligned
      scales_vec.store_to(dst);
    } else {
      // Safe path: element-wise store for tails or unaligned destinations
      scales_vec.store_to_elts(dst, 0, count);
    }
  }

  destroy_barriers<STAGES>(mbar, is_master_thread);
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

}  // namespace nvfp4_transpose

#endif  // CUDA_VERSION > 12080
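// Reference note (not part of the original kernel): conceptually, the per-block decode factor
// produced by compute_decoding_scaling_factor() above is the FP8-E4M3 value closest to
//   S_dec_b = block_amax * S_enc / 6.0f   // 6.0f == largest E2M1 magnitude (assumed here)
// so that the per-block encode factor 1 / (S_dec_b * S_dec) maps the block amax onto the FP4
// representable range. The exact rounding/saturation behavior lives in that helper.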
// Compile-time flag to choose kernel variant
#ifndef USE_2D_NVFP4_KERNEL
#define USE_2D_NVFP4_KERNEL 0
#endif
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
          bool use_2d_quantization>
void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
                              const QuantizationConfig *quant_config, cudaStream_t stream) {
#if CUDA_VERSION > 12080
  bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;

  // If the transposed output is allocated, return the transposed data. Otherwise, it's not
  // necessary to return the transposed data.
  // TODO(Frank): Is there a better way to do this?
  bool return_transpose = output->has_columnwise_data();

  using namespace nvfp4_transpose;
  using namespace ptx;
  checkCuDriverContext(stream);
  CheckNoopTensor(*noop, "cast_noop");
  CheckInputTensor(input, "input");
  CheckOutputTensor(*output, "output", false);

  NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
  NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated.");
  NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
  NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");

  if (return_transpose) {
    NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated.");
    NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype),
               "Transposed output must have FP4 type.");
    NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
               "Transposed scaling tensor must be allocated");
  }

  const size_t rows = input.flat_first_dim();
  const size_t cols = input.flat_last_dim();

  NVTE_CHECK(rows % 32 == 0, "Number of tensor rows must be a multiple of 32");  // 16B alignment for TMA
  NVTE_CHECK(cols % 32 == 0, "Number of tensor cols must be a multiple of 32");  // 16B alignment for TMA

  const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
  const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
  const dim3 grid(blocks_X, blocks_Y);
  const size_t block_size = THREADS_NUM;

  const size_t scale_stride = output->scale_inv.shape[1];
  const size_t scale_stride_transpose =
      return_transpose ? output->columnwise_scale_inv.shape[1] : 0;

  nvfp4_scale_t *const scales_ptr = reinterpret_cast<nvfp4_scale_t *>(output->scale_inv.dptr);
  nvfp4_scale_t *const scales_transpose_ptr =
      reinterpret_cast<nvfp4_scale_t *>(output->columnwise_scale_inv.dptr);

  const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
  const float *const amax_rowwise_ptr = reinterpret_cast<const float *>(output->amax.dptr);
  const float *const amax_colwise_ptr =
      reinterpret_cast<const float *>(output->columnwise_amax.dptr);

  const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr;

  const size_t *rng_state = nullptr;
  if (rng_state_tensor != nullptr) {
    Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
    NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
               "RNG state should contain 2 64-bit values.");
    NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
               "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
    rng_state = reinterpret_cast<const size_t *>(rng_state_te_tensor.data.dptr);
  }
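  // Note: rng_state is only consumed by the stochastic-rounding conversion path selected via
  // USE_STOCHASTIC_ROUNDING below; when no rng_state tensor is supplied it stays nullptr and
  // round-to-nearest conversion does not use it.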
  using IType = bf16;

  alignas(64) CUtensorMap tensor_map_input{};
  alignas(64) CUtensorMap tensor_map_output{};
  alignas(64) CUtensorMap tensor_map_output_transpose{};

  create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
                       sizeof(IType) * 8);
  create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols,
                       0, 4);

  if (return_transpose) {
    create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows,
                         BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4);
  }

  constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
  constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
  constexpr size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out =
      DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t);

  constexpr size_t in_mem = buff_size_aligned_in;

  constexpr size_t out_data_mem = buff_size_aligned_out;
  constexpr size_t out_data_transpose_mem = buff_size_aligned_out;
  constexpr size_t out_scales_transpose_mem = buff_size_scales;

  constexpr size_t out_mem = out_data_mem + out_data_transpose_mem;

  constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT;

  TRANSFORMER_ENGINE_SWITCH_CONDITION(
      use_stochastic_rounding, USE_STOCHASTIC_ROUNDING,
      TRANSFORMER_ENGINE_SWITCH_CONDITION(
          return_transpose, RETURN_TRANSPOSE, {
            auto kernel = nvfp4_transpose_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
                                                 USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE>;
            if constexpr (use_2d_quantization) {
              kernel = nvfp4_transpose_kernel_2D<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
                                                 USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE>;
            }
            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
            kernel<<<grid, block_size, dshmem_size, stream>>>(
                tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr,
                scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols,
                scale_stride, scale_stride_transpose, rng_state);
          }););
#else
  NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif  // CUDA_VERSION > 12080
}

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
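A side note on the numerics in this file: NVFP4 (E2M1) can only represent the magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}. Below is a minimal, self-contained CPU sketch of round-to-nearest quantization onto that value set; it is independent of the fused PTX path above, and the function names are illustrative only.

#include <cmath>
#include <cstdio>

// Illustrative only: round a (block-scaled) value to the nearest representable E2M1 magnitude.
float quantize_to_e2m1(float x) {
  static const float grid[] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
  const float mag = std::fabs(x);
  float best = grid[0];
  for (float g : grid) {
    if (std::fabs(mag - g) < std::fabs(mag - best)) best = g;
  }
  return std::copysign(best, x);
}

int main() {
  // Four values quantized with an illustrative block scale of 1.0; 7.3 saturates to 6.
  const float vals[] = {0.27f, -1.4f, 5.1f, 7.3f};
  for (float v : vals) std::printf("%f -> %f\n", v, quantize_to_e2m1(v));
  return 0;
}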
transformer_engine/common/util/ptx.cuh

@@ -14,6 +14,10 @@
 #include <cuda.h>
 #include <cuda_runtime.h>

+#if CUDA_VERSION >= 12080
+#include <cuda_fp4.h>
+#endif  // CUDA_VERSION >= 12080
+
 namespace transformer_engine {
 namespace ptx {

@@ -125,9 +129,13 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
   return __int_as_float(biased_exp << FP32_MANTISSA_BITS);
 }

+#define CUDA_ARCH_HAS_FEATURE_SM10X_ALL                                                  \
+  ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) ||   \
+   (__CUDA_ARCH_HAS_FEATURE__(SM103_ALL)))
+
 __device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
-#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
-     (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
+#if (CUDA_ARCH_HAS_FEATURE_SM10X_ALL || (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
   uint16_t out;
   asm volatile(
       "{\n"

@@ -230,18 +238,86 @@ struct alignas(2 * sizeof(T)) FPx2 {
   T y;
 };

+template <typename T>
+struct FPx4 {
+  T x1;
+  T x2;
+  T x3;
+  T x4;
+};
+
+template <typename T>
+struct Type2x {};
+template <>
+struct Type2x<float> {
+  using type = float2;
+};
+template <>
+struct Type2x<bf16> {
+  using type = __nv_bfloat162;
+};
+template <>
+struct Type2x<fp16> {
+  using type = __half2;
+};
+
 using floatx2 = FPx2<float>;
 using bf16x2 = FPx2<bf16>;
 using fp16x2 = FPx2<fp16>;
 using fp8e4m3x2 = FPx2<fp8e4m3>;
 using fp8e5m2x2 = FPx2<fp8e5m2>;

+using floatx4 = FPx4<float>;
+using bf16x4 = FPx4<bf16>;
+using fp16x4 = FPx4<fp16>;
+using fp8e4m3x4 = FPx4<fp8e4m3>;
+using fp8e5m2x4 = FPx4<fp8e5m2>;
+
 static_assert(sizeof(floatx2) == 8);
 static_assert(sizeof(bf16x2) == 4);
 static_assert(sizeof(fp16x2) == 4);
 static_assert(sizeof(fp8e4m3x2) == 2);
 static_assert(sizeof(fp8e5m2x2) == 2);

+#if CUDA_VERSION >= 12080
+using fp4e2m1 = __nv_fp4_e2m1;
+using fp4e2m1x2 = __nv_fp4x2_e2m1;
+using fp4e2m1x4 = __nv_fp4x4_e2m1;
+
+static_assert(sizeof(fp4e2m1x2) == 1);
+static_assert(sizeof(fp4e2m1x4) == 2);
+#endif  // CUDA_VERSION >= 12080
+
+// cvt.rn.satfinite.e2m1x2.f32 d, a, b;  // Convert two FP32 values to two packed e2m1
+// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 introduced in PTX ISA version 8.6.
+// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 is supported on the following
+// architectures: sm_100a, sm_101a, sm_120a.
+// When converting to .e2m1x2 data formats, the destination operand d has .b8 type.
+// When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format,
+// and the converted values are packed in the destination operand d such that the value
+// converted from input a is stored in the upper 4 bits of d and the value converted
+// from input b is stored in the lower 4 bits of d.
+
+// SIMD like "Fused" cast + multiplication (x4)
+#if CUDA_VERSION >= 12080
+template <typename Tx2>
+__device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, const Tx2 &in23,
+                                           const float scale) {
+  const float x0 = static_cast<float>(in01.x) * scale;
+  const float x1 = static_cast<float>(in01.y) * scale;
+  const float x2 = static_cast<float>(in23.x) * scale;
+  const float x3 = static_cast<float>(in23.y) * scale;
+  out = fp4e2m1x4(make_float4(x0, x1, x2, x3));
+}
+#endif  // CUDA_VERSION >= 12080
+
 // SIMD like "Fused" cast + multiplication (x2)
 __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
                                            const floatx2 &scale) {

@@ -377,7 +453,7 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const
       "r"(reinterpret_cast<const uint32_t &>(p2)));
 }

-#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)

 }  // namespace ptx
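A quick usage sketch for the new mul_cvt_4x helper: each call fuses a scale multiply with the cast of four inputs into one packed fp4e2m1x4. This assumes ptx.cuh is included (the include path below is illustrative) and a CUDA 12.8+ toolchain targeting an FP4-capable architecture; the kernel name and indexing are illustrative only, not part of this commit.

#include "transformer_engine/common/util/ptx.cuh"  // illustrative include path

__global__ void scale_and_pack_fp4(const float2 *in, float scale,
                                   transformer_engine::ptx::fp4e2m1x4 *out, size_t n_quads) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_quads) {
    // in[2*i] holds elements 0-1 and in[2*i + 1] holds elements 2-3 of the i-th quad.
    transformer_engine::ptx::mul_cvt_4x(out[i], in[2 * i], in[2 * i + 1], scale);
  }
}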
transformer_engine/common/util/pybind_helper.h

@@ -27,6 +27,7 @@
       .value("kBFloat16", transformer_engine::DType::kBFloat16)                               \
       .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3)                           \
       .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2)                           \
+      .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1)                           \
       .value("kInt8", transformer_engine::DType::kInt8);                                      \
   pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local())              \
       .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)                                    \

@@ -41,6 +42,10 @@
       .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)  \
       .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK",                                         \
              NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);                          \
+  pybind11::enum_<NVTE_Softmax_Type>(m, "NVTE_Softmax_Type", pybind11::module_local())        \
+      .value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)                 \
+      .value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX)           \
+      .value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX);            \
   pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local())            \
       .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD)                                         \
       .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD)                                         \
transformer_engine/common/util/vectorized_pointwise.h

@@ -11,7 +11,7 @@
 #include "../common.h"
 #include "../utils.cuh"
+#include "math.h"

 namespace transformer_engine {

 /* \brief Helper class that enables storing multiple values of type DType

@@ -345,7 +345,7 @@ template <int nvec, typename Param, fp32 (*OP)(const fp32, const Param &), typen
           typename OutputType>
 void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output,
                                    const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N,
-                                   const Param params, cudaStream_t stream) {
+                                   const Param &params, cudaStream_t stream) {
   if (N != 0) {
     auto align = CheckAlignment(N, nvec, input, output);

@@ -379,7 +379,7 @@ template <int nvec, typename Param, fp32 (*OP)(fp32, const Param &), typename In
           typename InputTypeGrad, typename OutputType>
 void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input,
                                        OutputType *output, const fp32 *scale, fp32 *amax,
-                                       fp32 *scale_inv, const size_t N, const Param params,
+                                       fp32 *scale_inv, const size_t N, const Param &params,
                                        cudaStream_t stream) {
   if (N != 0) {
     auto align = CheckAlignment(N, nvec, input, grad, output);

@@ -438,7 +438,13 @@ __launch_bounds__(unary_kernel_threads) __global__
 #pragma unroll
     for (int i = 0; i < nvec; ++i) {
       const ComputeType val = static_cast<ComputeType>(loader0.separate()[i]);
-      const ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
+      ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
+      if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
+        // Clamp the gated value and add 1 at the end
+        ComputeType limit = p.limit;
+        val2 = std::min(std::max(-limit, val2), limit) + 1;
+      }
       ComputeType temp = static_cast<ComputeType>(Activation(val, p) * val2);
       if (requires_amax) {
         __builtin_assume(max >= 0);
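The forward change above clamps the gate input to [-limit, limit] and adds 1 before multiplying by the activation output (the backward counterpart follows in the next hunk). A small standalone sketch of that element-wise math, using SiLU as the activation, which is an assumption here; `limit` mirrors ClampedSwiGLUParam::limit from this diff, and only the gate-side clamp shown in this hunk is reproduced:

#include <algorithm>
#include <cmath>

// Illustrative reference of the clamped-SwiGLU forward element computation.
float clamped_swiglu_ref(float val, float gate, float limit) {
  const float act = val / (1.0f + std::exp(-val));                             // SiLU(val), assumed activation
  const float clamped_gate = std::min(std::max(-limit, gate), limit) + 1.0f;   // clamp, then +1
  return act * clamped_gate;
}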
@@ -539,10 +545,18 @@ __launch_bounds__(unary_kernel_threads) __global__
     for (int i = 0; i < nvec; ++i) {
       const ComputeType grad_val = static_cast<ComputeType>(grad_loader.separate()[i]);
       const ComputeType gelu_in = static_cast<ComputeType>(input_loader0.separate()[i]);
-      const ComputeType gate_in = static_cast<ComputeType>(input_loader1.separate()[i]);
+      ComputeType gate_in = static_cast<ComputeType>(input_loader1.separate()[i]);
+
+      bool dgate_in = true;
+      if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
+        // In case of GPT OSS, clamp the activation and gate values
+        const ComputeType limit = p.limit;
+        dgate_in = gate_in <= limit && gate_in >= -limit;  // Derivative of clamp
+        gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f;
+      }
       ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;
-      ComputeType after_dgate = grad_val * Activation(gelu_in, p);
+      ComputeType after_dgate = dgate_in ? grad_val * Activation(gelu_in, p) : 0.0f;
       if (requires_amax) {
         __builtin_assume(max >= 0);
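And the corresponding backward tweak: the gradient through the gate path is masked to zero wherever the clamp was saturated (the derivative of the clamp), while the activation path sees the clamped gate plus 1. A hedged standalone sketch, again assuming SiLU as the activation:

#include <algorithm>
#include <cmath>

struct ClampedSwigluGrads { float dval; float dgate; };

// Illustrative reference of the clamped-SwiGLU backward element computation.
ClampedSwigluGrads clamped_swiglu_bwd_ref(float grad, float val, float gate, float limit) {
  const float sig = 1.0f / (1.0f + std::exp(-val));
  const float act = val * sig;                        // SiLU(val), assumed activation
  const float dact = sig + val * sig * (1.0f - sig);  // d SiLU / d val
  const bool within = (gate <= limit) && (gate >= -limit);                    // derivative of clamp
  const float clamped_gate = std::min(std::max(-limit, gate), limit) + 1.0f;  // clamp, then +1
  return {dact * grad * clamped_gate, within ? grad * act : 0.0f};
}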
transformer_engine/common/utils.cuh

@@ -49,6 +49,26 @@ constexpr uint32_t THREADS_PER_WARP = 32;
 ////////////////////////////////////////////////////////////////////////////////////////////////////

+// Device-side error
+#define NVTE_DEVICE_ERROR(message)                                                                \
+  do {                                                                                            \
+    printf("%s:%d in function %s (thread (%d,%d,%d), block (%d,%d,%d)): %s\n", __FILE__, __LINE__, \
+           __func__, threadIdx.x, threadIdx.y, threadIdx.z, blockIdx.x, blockIdx.y, blockIdx.z,   \
+           (message));                                                                            \
+    assert(0);                                                                                    \
+  } while (false)
+
+// Device-side error on thread 0
+#define NVTE_DEVICE_THREAD0_ERROR(message)                                                        \
+  do {                                                                                            \
+    if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 &&              \
+        threadIdx.y == 0 && threadIdx.z == 0) {                                                   \
+      NVTE_DEVICE_ERROR(message);                                                                 \
+    }                                                                                             \
+  } while (false)
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 #if !defined(USE_HIPBLASLT) && !defined(__HIPCC_RTC__)
 inline __device__ float2 operator+(const float2 &a, const float2 &b) {  // NOLINT(*)
   return {a.x + b.x, a.y + b.y};
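The new device-side error macros are plain printf-plus-assert helpers; a minimal usage sketch in a kernel (the kernel name and the checked condition are illustrative, and utils.cuh is assumed to be included):

__global__ void check_shape_kernel(const float *data, size_t n, size_t expected) {
  if (n != expected) {
    // Reports file/line/function plus thread and block coordinates, then asserts (thread 0 only).
    NVTE_DEVICE_THREAD0_ERROR("unexpected tensor size");
    return;
  }
  // ... normal kernel work would go here ...
}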
transformer_engine/debug/features/fake_quant.py

@@ -19,7 +19,7 @@ from transformer_engine.common.recipe import Format
 from transformer_engine.pytorch.tensor import Quantizer
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
-from transformer_engine.pytorch.fp8 import _default_sf_compute
+from transformer_engine.pytorch.quantization import _default_sf_compute


 def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None):
transformer_engine/debug/features/log_fp8_tensor_stats.py

@@ -290,10 +290,16 @@ class LogFp8TensorStats(BaseLogTensorStats):
         for stat in config["stats"]:
             self.check_if_stat_is_supported(stat, recipe_name)

+        start_step = config.get("start_step", None)
+        end_step = config.get("end_step", None)
+        start_end_list = config.get("start_end_list", None)
+        if start_end_list is not None:
+            start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list)
+
         options = (
-            config.get("start_step", None),
-            config.get("end_step", None),
-            config.get("start_end_list", None),
+            start_step,
+            end_step,
+            start_end_list,
             "fp8",
         )
transformer_engine/debug/features/log_tensor_stats.py

@@ -15,8 +15,8 @@ import nvdlfw_inspect.api as debug_api
 from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
-from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
-from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
+from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage
+from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
 from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
 from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params

@@ -123,17 +123,23 @@ class LogTensorStats(BaseLogTensorStats):
         """API call used to collect the data about the tensor before process_tensor()/quantization."""
         assert (
-            type(tensor) not in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase]
+            type(tensor) not in [Float8Tensor, Float8TensorStorage, MXFP8Tensor, MXFP8TensorStorage]
             and tensor.dtype != torch.uint8
         ), (
             f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be in high precision when using"
             " log_tensor_stats. Use log_fp8_tensor_stats for FP8 tensors."
         )

+        start_step = config.get("start_step", None)
+        end_step = config.get("end_step", None)
+        start_end_list = config.get("start_end_list", None)
+        if start_end_list is not None:
+            start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list)
+
         options = (
-            config.get("start_step", None),
-            config.get("end_step", None),
-            config.get("start_end_list", None),
+            start_step,
+            end_step,
+            start_end_list,
         )
         skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
transformer_engine/debug/features/utils/stats_buffer.py

@@ -172,11 +172,19 @@ class StatsBuffers:
         if self.at_least_one_layer_fed:
             return True
         iteration = TEDebugState.get_iteration()
-        for _, next_iter in self.layers_to_next_iter.items():
+        layers_to_remove = []
+        for layer_name, next_iter in self.layers_to_next_iter.items():
+            # When next_iter is None the feature will no longer run.
+            if next_iter is None:
+                layers_to_remove.append(layer_name)
+                continue
             # Note that a layer may go unused for many iterations;
             # in that case we keep synchronizing every step until we get any information from it.
             if iteration >= next_iter:
                 return True
+        for layer_name in layers_to_remove:
+            self.layers_to_next_iter.pop(layer_name, None)
         return False

     def reset(self):
transformer_engine/debug/pytorch/debug_quantization.py

@@ -18,7 +18,7 @@ from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch.tensor.quantized_tensor import (
     QuantizedTensor,
     Quantizer,
-    QuantizedTensorBase,
+    QuantizedTensorStorage,
     prepare_for_saving,
     restore_from_saved,
 )

@@ -557,7 +557,7 @@ class DebugQuantizer(Quantizer):
         self._update_parent_quantizer_usage()


-class DebugQuantizedTensor(QuantizedTensorBase):
+class DebugQuantizedTensor(QuantizedTensorStorage):
     """
     Class containing quantized tensors after debug. Depending on configuration
     it can contain one or two different objects. These objects can be accessed by the method
transformer_engine/jax/__init__.py

@@ -34,7 +34,7 @@ load_framework_extension("jax")
 from . import flax
 from . import quantize

-from .quantize import fp8_autocast, update_collections, get_delayed_scaling
+from .quantize import autocast, fp8_autocast, update_collections
 from .quantize import NVTE_FP8_COLLECTION_NAME
 from .sharding import MeshResource

@@ -45,9 +45,9 @@ from ..common.utils import DeprecatedEnum
 __all__ = [
     "NVTE_FP8_COLLECTION_NAME",
+    "autocast",
     "fp8_autocast",
     "update_collections",
-    "get_delayed_scaling",
     "MeshResource",
     "flax",
     "quantize",