OpenDAS / TransformerEngine · Commits

Commit 53fa872c, authored Oct 11, 2025 by wenjh
Merge branch 'nv_release_v2.8' into release_v2.8
Parents: 27ddce40, 40c69e75

Changes: 159 files; showing 20 changed files with 4256 additions and 259 deletions (+4256 −259)
transformer_engine/common/hadamard_transform/hadamard_transform.cu (+876, −0)
transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu (+841, −0)
transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h (+25, −0)
transformer_engine/common/include/transformer_engine/fused_attn.h (+73, −37)
transformer_engine/common/include/transformer_engine/gemm.h (+182, −7)
transformer_engine/common/include/transformer_engine/hadamard_transform.h (+68, −0)
transformer_engine/common/include/transformer_engine/recipe.h (+4, −0)
transformer_engine/common/include/transformer_engine/transformer_engine.h (+46, −4)
transformer_engine/common/normalization/layernorm/ln_api.cpp (+4, −4)
transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp (+4, −4)
transformer_engine/common/recipe/__init__.py (+115, −10)
transformer_engine/common/recipe/current_scaling.cu (+19, −8)
transformer_engine/common/recipe/nvfp4.cu (+54, −0)
transformer_engine/common/swizzle/swizzle.cu (+153, −120)
transformer_engine/common/transformer_engine.cpp (+79, −7)
transformer_engine/common/transpose/cast_transpose.h (+9, −0)
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu (+842, −0)
transformer_engine/common/util/cast_gated_kernels.cuh (+4, −1)
transformer_engine/common/util/cast_kernels.cuh (+761, −46)
transformer_engine/common/util/dequantize_kernels.cuh (+97, −11)
transformer_engine/common/hadamard_transform/hadamard_transform.cu (new file, 0 → 100644)
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <transformer_engine/hadamard_transform.h>
#include <cuda/barrier>
#include "common/common.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
namespace transformer_engine {
namespace {

constexpr int kThreadsPerWarp = 32;
constexpr float k16x16HadamardScale = 0.25f;
template <bool kTranspose>
__device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t &a0, uint32_t &a1,
                                                            uint32_t &a2, uint32_t &a3,
                                                            void *addr) {
  auto smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(addr));
  if constexpr (kTranspose) {
    asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
                 : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
                 : "r"(smem_addr));
  } else {
    asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
                 : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
                 : "r"(smem_addr));
  }
}
template <bool kTranspose>
__device__ __forceinline__ void load_matrix_16x16_from_shared(uint32_t &a0, uint32_t &a1,
                                                              uint32_t &a2, uint32_t &a3,
                                                              void *addr, uint32_t stride) {
  if constexpr (kTranspose) {
    asm volatile(
        "wmma.load.a.sync.aligned.col.m16n16k16.shared::cta.bf16 "
        "{%0,%1,%2,%3}, [%4], %5;\n"
        : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
        : "l"(addr), "r"(stride));
  } else {
    asm volatile(
        "wmma.load.a.sync.aligned.row.m16n16k16.shared::cta.bf16 "
        "{%0,%1,%2,%3}, [%4], %5;\n"
        : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
        : "l"(addr), "r"(stride));
  }
}
template <bool kTranspose>
__device__ __forceinline__ void store_matrix_16x16_to_global(uint32_t &a0, uint32_t &a1,
                                                             uint32_t &a2, uint32_t &a3,
                                                             void *addr, uint32_t stride) {
  if constexpr (kTranspose) {
    asm volatile(
        "wmma.store.d.sync.aligned.col.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
        :
        : "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
  } else {
    asm volatile(
        "wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
        :
        : "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
  }
}
__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(uint32_t &a0) {
  asm volatile(
      "movmatrix.sync.aligned.m8n8.trans.b16 "
      "%0, %1;\n\t"
      : "=r"(a0)
      : "r"(a0));
}
__device__ __forceinline__ void unpack_max_of_packed_bf16(uint32_t &packed_bf16,
                                                          float &float_dst) {
  __nv_bfloat162 bf16x2 = *reinterpret_cast<__nv_bfloat162 *>(&packed_bf16);
  float f_a = __bfloat162float(bf16x2.x);
  float f_b = __bfloat162float(bf16x2.y);
  asm volatile("max.xorsign.abs.f32 %0, %1, %2;\n\t"
               : "=f"(float_dst)
               : "f"(f_a), "f"(f_b));
  float_dst = fabsf(float_dst);
}
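
// Note added for readability (not part of the original commit): per the PTX semantics of
// max.xorsign.abs, the instruction above returns max(|f_a|, |f_b|) with its sign set to
// sign(f_a) XOR sign(f_b); the fabsf() afterwards drops that sign, so float_dst ends up as
// max(|x|, |y|) over the two packed bf16 halves. Example: packed halves x = -3.0, y = 2.0
// give float_dst = 3.0.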
template <bool kCalculateAmax>
__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc(
    uint32_t &a0, uint32_t &a1, uint32_t &a2, uint32_t &a3, uint32_t &b0, uint32_t &b1,
    uint32_t &b2, uint32_t &b3, uint32_t &c0, uint32_t &c1, uint32_t &c2, uint32_t &c3,
    uint32_t &amax_result) {
  uint32_t zero = 0;
  uint32_t temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7;
  asm volatile(
      "wmma.mma.sync.aligned.row.row.m16n16k16.f32.bf16.bf16.f32\n"
      "{%0, %1, %2, %3, %4, %5, %6, %7},\n"
      "{%8, %9, %10, %11},\n"
      "{%12, %13, %14, %15},\n"
      "{%16, %17, %18, %19, %20, %21, %22, %23};\n\t"
      : "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3), "=r"(temp4), "=r"(temp5),
        "=r"(temp6), "=r"(temp7)
      : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), "r"(b3), "r"(zero),
        "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero));
  asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0));
  asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2));
  asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c2) : "r"(temp5), "r"(temp4));
  asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c3) : "r"(temp7), "r"(temp6));
  if constexpr (kCalculateAmax) {
    uint32_t max_even;
    uint32_t max_odd;
    // Reduction tree to amax(abs(result)) into bf16x2 reg outparam.
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(max_even)
                 : "r"(c0), "r"(c2));
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(max_odd)
                 : "r"(c1), "r"(c3));
    // N.B. mma is only called up to once per thread for identity and transpose respectively, so
    // we don't have to accumulate into amax_result and can directly store into it.
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(amax_result)
                 : "r"(max_even), "r"(max_odd));
  }
}
template <bool kReturnIdentity, bool kReturnTransposed, bool kInverseHadamardIdentity,
          bool kInverseHadamardTransposed>
__device__ __forceinline__ void get_hadamard_matrix_fragment(uint32_t *had_frag_i,
                                                             uint16_t random_sign_mask,
                                                             uint32_t *had_frag_t,
                                                             uint16_t random_sign_mask_t) {
  int32_t tid = threadIdx.x % 32;  // Local tid
  float temp_i[2];
  float temp_t[2];
#pragma unroll
  for (int i = 0; i < 2; i++) {
    // i is the vertical fragment index.
    // For a 16x16 matrix fragment, 4 threads fill a fragment of 8 BF16 vals.
    uint32_t r = i * 8 + tid / 4;
#pragma unroll
    for (int j = 0; j < 2; j++) {
#pragma unroll
      for (int k = 0; k < 2; k++) {
        // k is column position [0, 1] within a quad of 2 BF16s stored together in 32 bits.
        // j is the column fragment idx selecting between even and odd fragments.
        // j increments 8 columns by switching fragments.
        uint32_t c = j * 8 + k + tid % 4 * 2;
        // 1 -> -1.0f, 0 -> 1.0f
        int32_t base_sign = __popc(r & c);
        if constexpr (kReturnIdentity) {
          int32_t sign_i;
          // Because tensor cores want the dot product dimension contiguous, the regular,
          // non-inverse hadamard swaps signs of columns and rows relative to the inverse.
          // In a simple reference, x.reshape(-1, 16) @ sign @ H16, this would be opposite,
          // but (sign @ H16) is transposed in this fragment.
          if constexpr (kInverseHadamardIdentity) {
            sign_i = ((random_sign_mask >> r) ^ base_sign);
          } else {
            sign_i = ((random_sign_mask >> c) ^ base_sign);
          }
          temp_i[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_i << 31));
        }
        if constexpr (kReturnTransposed) {
          int32_t sign_t;
          if constexpr (kInverseHadamardTransposed) {
            sign_t = ((random_sign_mask_t >> r) ^ base_sign);
          } else {
            sign_t = ((random_sign_mask_t >> c) ^ base_sign);
          }
          temp_t[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_t << 31));
        }
      }
      if constexpr (kReturnIdentity) {
        asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
                     : "=r"(had_frag_i[i * 2 + j])
                     : "f"(temp_i[1]), "f"(temp_i[0]));
      }
      if constexpr (kReturnTransposed) {
        asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
                     : "=r"(had_frag_t[i * 2 + j])
                     : "f"(temp_t[1]), "f"(temp_t[0]));
      }
    }
  }
}
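
// --- Illustration added for readability; not part of the original commit. ---
// The sign pattern built above encodes the Sylvester-ordered 16x16 Hadamard matrix with
// normalization k16x16HadamardScale = 1/4: H16[r][c] = (-1)^popcount(r & c) * 0.25f, with
// optional per-row/column sign flips selected by random_sign_mask. A host-side constexpr
// mirror (hypothetical helper used only for these compile-time sanity checks):
constexpr float IllustrativeHadamard16Entry(uint32_t r, uint32_t c) {
  return (__builtin_popcount(r & c) & 1) ? -0.25f : 0.25f;
}
static_assert(IllustrativeHadamard16Entry(0, 0) == 0.25f, "H16[0][0] = +1/4");
static_assert(IllustrativeHadamard16Entry(1, 1) == -0.25f, "popcount(1 & 1) = 1 -> -1/4");
static_assert(IllustrativeHadamard16Entry(3, 3) == 0.25f, "popcount(3 & 3) = 2 -> +1/4");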
__device__ __forceinline__ uint32_t swizzle_128B_atom_32B(uint32_t gmem_row_idx,
                                                          uint32_t gmem_col_idx) {
  uint32_t smem_row_idx = gmem_row_idx;
  uint32_t xor_factor = (smem_row_idx * 2) % 8;
  uint32_t smem_col_idx = gmem_col_idx ^ xor_factor;
  return smem_row_idx * 8 + smem_col_idx;
}
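
// --- Illustration added for readability; not part of the original commit. ---
// swizzle_128B_atom_32B XORs the chunk column index with (row * 2) % 8, presumably mirroring
// the CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B pattern used for the TMA loads later in this file,
// so that ldmatrix reads back the data from shared memory in the expected slots. A host-side
// constexpr mirror (hypothetical, illustration only):
constexpr uint32_t IllustrativeSwizzleIndex(uint32_t gmem_row_idx, uint32_t gmem_col_idx) {
  return gmem_row_idx * 8 + (gmem_col_idx ^ ((gmem_row_idx * 2) % 8));
}
static_assert(IllustrativeSwizzleIndex(0, 3) == 3, "row 0 is stored unswizzled");
static_assert(IllustrativeSwizzleIndex(1, 0) == 10, "row 1, col 0 -> slot 8 + (0 ^ 2)");
static_assert(IllustrativeSwizzleIndex(2, 5) == 17, "row 2, col 5 -> slot 16 + (5 ^ 4)");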
template <typename IType, int kHadamardDimension, int BUFF_DIM_Y, int BUFF_DIM_X,
          bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
                                              IType *in_sh_ptr,
                                              uint32_t &local_pre_rht_amax_reg,
                                              uint32_t &local_amax_reg,
                                              uint32_t &local_amax_t_reg) {
  uint32_t a_frag[4];  // A matrix fragment
  uint32_t c_frag[4];  // Result fragment
  int warp_id = threadIdx.x / kThreadsPerWarp;
  int local_rank = (threadIdx.x % kThreadsPerWarp);
  int ld_row_idx = local_rank % kHadamardDimension;
  int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
  int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);
  uint32_t temp_amax_reg;
  uint32_t temp_amax_t_reg;
  if (kReturnIdentityAmax) {
    ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                       reinterpret_cast<uint4 *>(in_sh_ptr) + swizzle_idx);
    mma_m16_n16_k16_b16_b16_b16_noacc<kReturnIdentityAmax>(
        a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
        b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(local_amax_reg)
                 : "r"(local_amax_reg), "r"(temp_amax_reg));
  }
  if (kReturnTransposedAmax) {
    // TODO(Frank): This is not efficient, since we could directly load the
    // matrix in transposed layout.
    if (!kReturnIdentityAmax) {
      ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                         reinterpret_cast<uint4 *>(in_sh_ptr) + swizzle_idx);
    }
    matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
    matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
    matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
    matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
    mma_m16_n16_k16_b16_b16_b16_noacc<kReturnTransposedAmax>(
        a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2],
        b_frag_t[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_t_reg);
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(local_amax_t_reg)
                 : "r"(local_amax_t_reg), "r"(temp_amax_t_reg));
  }
  if (kReturnPreRhtAmax) {
    if (!kReturnIdentityAmax && !kReturnTransposedAmax) {
      ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                         reinterpret_cast<uint4 *>(in_sh_ptr) + swizzle_idx);
    }
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(a_frag[0])
                 : "r"(a_frag[0]), "r"(a_frag[1]));
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(a_frag[2])
                 : "r"(a_frag[2]), "r"(a_frag[3]));
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(a_frag[0])
                 : "r"(a_frag[0]), "r"(a_frag[2]));
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(local_pre_rht_amax_reg)
                 : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg));
  }
}
template <int kN>
__device__ __host__ constexpr int NextPowerOf2() {
  static_assert(kN > 0, "kN must be > 0");
  // Round up to the next power of 2 by counting leading zeros.
  return 1 << (32 - __builtin_clz(kN - 1));
}
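
// --- Illustration added for readability; not part of the original commit. ---
// Compile-time sanity checks of the rounding behavior (exact powers of two map to themselves):
static_assert(NextPowerOf2<3>() == 4, "3 rounds up to 4");
static_assert(NextPowerOf2<4>() == 4, "exact powers of two are unchanged");
static_assert(NextPowerOf2<5>() == 8, "5 rounds up to 8");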
template <int kNumWarps, bool kReturnPreRhtAmax, bool kReturnIdentityAmax,
          bool kReturnTransposedAmax>
__device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float identity_amax,
                                          const float transpose_amax,
                                          float *staging_for_pre_rht,
                                          float *staging_for_identity,
                                          float *staging_for_transpose,
                                          float *output_pre_rht_amax_ptr,
                                          float *output_identity_amax_ptr,
                                          float *output_transpose_amax_ptr, const int warpid) {
  // intra-warp reduction
  constexpr int kWarpSize = 32;
  int local_rank = threadIdx.x % 32;
  float warp_pre_rht_amax =
      kReturnPreRhtAmax ? warp_reduce_max<kWarpSize>(pre_rht_amax) : 0.0f;
  float warp_identity_amax =
      kReturnIdentityAmax ? warp_reduce_max<kWarpSize>(identity_amax) : 0.0f;
  float warp_transpose_amax =
      kReturnTransposedAmax ? warp_reduce_max<kWarpSize>(transpose_amax) : 0.0f;
  // inter-warp reduction
  if (threadIdx.x % 32 == 0) {
    if (kReturnPreRhtAmax) {
      staging_for_pre_rht[warpid] = warp_pre_rht_amax;
    }
    if (kReturnIdentityAmax) {
      staging_for_identity[warpid] = warp_identity_amax;
    }
    if (kReturnTransposedAmax) {
      staging_for_transpose[warpid] = warp_transpose_amax;
    }
  }
  __syncthreads();
  constexpr int kNumWarpsPow2 = NextPowerOf2<kNumWarps>();
  if (warpid == 0) {
    if (kReturnIdentityAmax) {
      float identity_accum = local_rank < kNumWarps ? staging_for_identity[local_rank] : 0.0f;
      identity_accum = warp_reduce_max<kNumWarpsPow2>(identity_accum);
      if (local_rank == 0) {
        atomicMaxFloat(output_identity_amax_ptr, identity_accum);
      }
    }
  }
  if (warpid == 1) {
    if (kReturnTransposedAmax) {
      float transpose_accum = local_rank < kNumWarps ? staging_for_transpose[local_rank] : 0.0f;
      transpose_accum = warp_reduce_max<kNumWarpsPow2>(transpose_accum);
      if (local_rank == 0) {
        atomicMaxFloat(output_transpose_amax_ptr, transpose_accum);
      }
    }
  }
  if (warpid == 2) {
    if (kReturnPreRhtAmax) {
      float pre_rht_accum = local_rank < kNumWarps ? staging_for_pre_rht[local_rank] : 0.0f;
      pre_rht_accum = warp_reduce_max<kNumWarpsPow2>(pre_rht_accum);
      if (local_rank == 0) {
        atomicMaxFloat(output_pre_rht_amax_ptr, pre_rht_accum);
      }
    }
  }
}
__launch_bounds__(1) __global__
    void ZeroAmaxKernel(float *__restrict__ output_pre_rht_amax_ptr,
                        float *__restrict__ output_identity_amax_ptr,
                        float *__restrict__ output_transpose_amax_ptr) {
  if (output_pre_rht_amax_ptr != nullptr) {
    *output_pre_rht_amax_ptr = 0;
  }
  if (output_identity_amax_ptr != nullptr) {
    *output_identity_amax_ptr = 0;
  }
  if (output_transpose_amax_ptr != nullptr) {
    *output_transpose_amax_ptr = 0;
  }
}
template <typename IType, int kHadamardDimension, int CHUNK_DIM_Y, int CHUNK_DIM_X,
          int BUFF_DIM_Y, int BUFF_DIM_X, int THREADS_PER_CHUNK, int THREADS_PER_Y,
          bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor_map_input,
                                      float *__restrict__ output_pre_rht_amax_ptr,
                                      float *__restrict__ output_identity_amax_ptr,
                                      float *__restrict__ output_transpose_amax_ptr,
                                      uint16_t random_sign_mask, uint16_t random_sign_mask_t,
                                      uint64_t num_rows, uint64_t row_length) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y && CHUNK_DIM_Y % BUFF_DIM_Y == 0);
  static_assert(CHUNK_DIM_X >= BUFF_DIM_X && CHUNK_DIM_X % BUFF_DIM_X == 0);
  constexpr size_t STAGES_Y = CHUNK_DIM_Y / BUFF_DIM_Y;
  constexpr size_t STAGES_X = CHUNK_DIM_X / BUFF_DIM_X;
  constexpr int kNumWarps = (THREADS_PER_CHUNK * THREADS_PER_Y) / kThreadsPerWarp;
  const int input_block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
  const int input_block_offset_X = blockIdx.x * CHUNK_DIM_X;

  extern __shared__ __align__(128) char dynamic_shmem[];
  uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding.
  // __align__(128) does not guarantee the pointer to be aligned!
  uint8_t *dshmem = reinterpret_cast<uint8_t *>((base_shmem_ptr + 127) & ~127ULL);

  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned.
  constexpr size_t in_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType);
  IType *in_sh_0 = reinterpret_cast<IType *>(dshmem);
  dshmem += in_buff_size;
  IType *in_sh_1 = reinterpret_cast<IType *>(dshmem);
  dshmem += in_buff_size;
  IType *in_shs[2] = {in_sh_0, in_sh_1};
  constexpr int shmem_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType);

  const bool is_master_thread = (threadIdx.x == 0 && threadIdx.y == 0);

  // Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  uint64_t *mbar = reinterpret_cast<uint64_t *>(dshmem);
  dshmem += sizeof(uint64_t) * (STAGES_X * STAGES_Y);
  float *max_staging_identity = reinterpret_cast<float *>(dshmem);
  dshmem += sizeof(float) * kNumWarps;
  float *max_staging_transpose = reinterpret_cast<float *>(dshmem);
  dshmem += sizeof(float) * kNumWarps;
  float *max_staging_pre_rht = reinterpret_cast<float *>(dshmem);
  dshmem += sizeof(float) * kNumWarps;

  initialize_barriers<STAGES_X * STAGES_Y, THREADS_PER_CHUNK * THREADS_PER_Y>(mbar,
                                                                              is_master_thread);

  copy_2d_to_shared(in_shs[0], reinterpret_cast<const void *>(&tensor_map_input),
                    input_block_offset_X, input_block_offset_Y, shmem_buff_size, &mbar[0],
                    is_master_thread);

  uint32_t had_frag_i[4];
  uint32_t had_frag_t[4];
  get_hadamard_matrix_fragment<kReturnIdentityAmax, kReturnTransposedAmax, false, false>(
      had_frag_i, random_sign_mask, had_frag_t, random_sign_mask_t);

  float local_pre_rht_amax = 0.0;
  float local_amax = 0.0;
  float local_amax_t = 0.0;
  uint32_t local_pre_rht_amax_reg = *reinterpret_cast<uint32_t *>(&local_pre_rht_amax);
  uint32_t local_amax_reg = *reinterpret_cast<uint32_t *>(&local_amax);
  uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t *>(&local_amax_t);

  for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
    for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
      int stage = STAGES_X * stage_y + stage_x;
      const int next_stage = stage + 1;
      const int next_stage_x = stage_x + 1 == STAGES_X ? 0 : stage_x + 1;
      const int next_stage_y = stage_x + 1 == STAGES_X ? stage_y + 1 : stage_y;
      if (next_stage < STAGES_X * STAGES_Y) {
        const int input_global_offset_Y = input_block_offset_Y + next_stage_y * BUFF_DIM_Y;
        const int input_global_offset_X = input_block_offset_X + next_stage_x * BUFF_DIM_X;
        copy_2d_to_shared(in_shs[next_stage % 2],  // ping-pong
                          reinterpret_cast<const void *>(&tensor_map_input),
                          input_global_offset_X, input_global_offset_Y, shmem_buff_size,
                          &mbar[next_stage], is_master_thread);
      }
      ptx::fence_proxy_async_shared_cta();
      // Wait for the data to have arrived.
      ptx::mbarrier_wait_parity(&mbar[stage], 0);

      const size_t compute_stage_x_num =
          BUFF_DIM_X / (kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp));
      const size_t compute_stage_y_num = BUFF_DIM_Y / (kHadamardDimension * THREADS_PER_Y);
      const size_t in_row_stride = BUFF_DIM_X;
      IType *in_sh_ptr = in_shs[stage % 2];
#pragma unroll
      for (size_t compute_stage_y = 0; compute_stage_y < compute_stage_y_num;
           compute_stage_y++) {
        const int row_idx_offset = (compute_stage_y * kHadamardDimension * THREADS_PER_Y +
                                    threadIdx.y * kHadamardDimension);
        const int in_row_offset = row_idx_offset * in_row_stride;
#pragma unroll
        for (size_t compute_stage_x = 0; compute_stage_x < compute_stage_x_num;
             compute_stage_x++) {
          ComputeKernel<IType, kHadamardDimension, BUFF_DIM_Y, BUFF_DIM_X, kReturnPreRhtAmax,
                        kReturnIdentityAmax, kReturnTransposedAmax>(
              had_frag_i, had_frag_t,
              in_sh_ptr + in_row_offset +
                  (compute_stage_x * kHadamardDimension *
                   (THREADS_PER_CHUNK / kThreadsPerWarp)),
              local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
        }
        // Ensure all threads have finished their computation before new data over-writes the
        // shared memory.
        __syncthreads();
      }
    }
  }

  const int warpid = (threadIdx.x + threadIdx.y * blockDim.x) / kThreadsPerWarp;
  if constexpr (kReturnPreRhtAmax) {
    unpack_max_of_packed_bf16(local_pre_rht_amax_reg, local_pre_rht_amax);
  }
  if constexpr (kReturnIdentityAmax) {
    unpack_max_of_packed_bf16(local_amax_reg, local_amax);
  }
  if constexpr (kReturnTransposedAmax) {
    unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t);
  }
  ReduceMax<kNumWarps, kReturnPreRhtAmax, kReturnIdentityAmax, kReturnTransposedAmax>(
      local_pre_rht_amax, local_amax, local_amax_t, max_staging_pre_rht, max_staging_identity,
      max_staging_transpose, output_pre_rht_amax_ptr, output_identity_amax_ptr,
      output_transpose_amax_ptr, warpid);
  destroy_barriers<STAGES_X * STAGES_Y>(mbar, is_master_thread);
#else
  NVTE_DEVICE_ERROR("Kernel is only supported on SM 10.0+.");
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
template <typename T, int kHadamardDimension, bool kComputeIdentity, bool kComputeTransposed,
          bool kReturnIdentity, bool kReturnTransposed, bool kUpdateIdentityAmax,
          bool kUpdateTransposeAmax, bool kOutputTrueTransposed>
__global__ void HadamardTransformKernel(const T *__restrict__ input, T *__restrict__ output,
                                        T *__restrict__ output_t, uint16_t random_sign_mask,
                                        uint16_t random_sign_mask_t, uint64_t num_input_rows,
                                        uint64_t num_input_cols, float *__restrict__ amax,
                                        float *__restrict__ amax_t, bool inverse_hadamard) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  static_assert(kHadamardDimension == 16, "Currently only hadamard dimension 16 is supported.");
  // The whole threadblock will share the same smem.
  extern __shared__ __align__(16) T smem[];
  // Each 32 threads process a 16x16 matrix. There is a (y, z) grid of 16x16.
  // If y = 4, z = 4, then each threadblock is processing a 4x4 grid of 16x16 matrices.
  int32_t tid = threadIdx.x;
  int32_t warp_id = threadIdx.y * blockDim.z + threadIdx.z;
  int32_t local_bx = threadIdx.y;
  int32_t local_by = threadIdx.z;
  // Define the register fragments
  uint32_t a_frag[4];    // A matrix fragment
  uint32_t b_frag_i[4];  // Transposed Hadamard matrix fragment, used for A @ B(col major)
  uint32_t b_frag_t[4];  // Hadamard matrix fragment, used for A.T @ B.T(col major)
  uint32_t c_frag[4];    // Result fragment
  // row and col for each thread. 32 threads will work together in 128-byte chunks to
  // load the data from global memory to shared memory.
  uint32_t row = tid / (kHadamardDimension * sizeof(T) / sizeof(uint4));
  uint32_t col = tid % (kHadamardDimension * sizeof(T) / sizeof(uint4));
  uint32_t smem_index = tid;
  uint32_t input_start_col = (blockIdx.x * blockDim.y + local_bx) * kHadamardDimension;
  uint32_t input_start_row = (blockIdx.y * blockDim.z + local_by) * kHadamardDimension;
  bool load = (input_start_col < num_input_cols) && (input_start_row < num_input_rows);
  if (!load) {
    // Out of bound, we are returning early. No thread divergence since the whole warp
    // will return early.
    return;
  }
  uint64_t global_offset = input_start_col + input_start_row * num_input_cols;
  uint64_t global_offset_t = kOutputTrueTransposed
                                 ? (input_start_row + input_start_col * num_input_rows)
                                 : global_offset;
  T *base_smem = smem + kHadamardDimension * kHadamardDimension * warp_id;
  uint32_t *smem_b32 = reinterpret_cast<uint32_t *>(base_smem);
  uint4 *smem_b128 = reinterpret_cast<uint4 *>(base_smem);
  // Asynchronously load the data from global memory to shared memory.
  const uint4 *input_b128 = reinterpret_cast<const uint4 *>(input + global_offset);
  // Each 16x16 chunk is divided into 4 8x8 matrices; we are trying to load each
  // 8x8 chunk consecutively into the smem, so we could leverage ldmatrix m8n8x4
  // to load the data in the tensor core swizzled format.
  __pipeline_memcpy_async(
      &smem_b128[smem_index],
      &input_b128[row * num_input_cols / (sizeof(uint4) / sizeof(T)) + col], sizeof(uint4));
  __pipeline_commit();  // Commit the memcpy. Wait when we are in the computation.
  if (inverse_hadamard) {
    get_hadamard_matrix_fragment<kComputeIdentity, kComputeTransposed,
                                 /*kInverseHadamard=*/true,
                                 /*kInverseHadamardTransposed=*/true>(
        b_frag_i, random_sign_mask, b_frag_t, random_sign_mask_t);
  } else {
    get_hadamard_matrix_fragment<kComputeIdentity, kComputeTransposed,
                                 /*kInverseHadamard=*/false,
                                 /*kInverseHadamardTransposed=*/false>(
        b_frag_i, random_sign_mask, b_frag_t, random_sign_mask_t);
  }
  float local_amax = 0.0;
  float local_amax_t = 0.0;
  uint32_t local_amax_reg = *reinterpret_cast<uint32_t *>(&local_amax);
  uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t *>(&local_amax_t);
  __pipeline_wait_prior(0);
  __syncwarp();  // ensure all lanes finished their cp.async before reading smem
  // Load the A to a_frag.
  if constexpr (kComputeIdentity) {
    load_matrix_16x16_from_shared<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3], smem_b32,
                                         kHadamardDimension);
    // 16x16 @ 16x16 leveraging all threads in the warp.
    mma_m16_n16_k16_b16_b16_b16_noacc<kUpdateIdentityAmax>(
        a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
        b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], local_amax_reg);
    // Store the result to global memory in non-transposed order.
    if constexpr (kReturnIdentity) {
      uint4 *output_b128 = reinterpret_cast<uint4 *>(output + global_offset);
      store_matrix_16x16_to_global<false>(c_frag[0], c_frag[1], c_frag[2], c_frag[3],
                                          output_b128, num_input_cols);
    }
  }
  if constexpr (kComputeTransposed) {
    if (kComputeIdentity) {
      matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
      matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
      matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
      matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
    } else {
      load_matrix_16x16_from_shared<true>(a_frag[0],
                                          a_frag[2],  // NOTE: intentional index swapping
                                          a_frag[1],  // NOTE: intentional index swapping
                                          a_frag[3], smem_b32, kHadamardDimension);
    }
    mma_m16_n16_k16_b16_b16_b16_noacc<kUpdateTransposeAmax>(
        a_frag[0],
        // The 2,1 order is used if we are using the movmatrix instruction.
        // Thus loading the matrix in 2,1 order will just be normal.
        // This is to be compatible with the movmatrix instruction.
        a_frag[2],  // NOTE: intentional index swapping for transpose purpose.
        a_frag[1],  // NOTE: intentional index swapping for transpose purpose.
        a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], b_frag_t[3], c_frag[0], c_frag[1],
        c_frag[2], c_frag[3], local_amax_t_reg);
    // Store the result to global memory.
    if constexpr (kReturnTransposed) {
      uint4 *output_t_b128 = reinterpret_cast<uint4 *>(output_t + global_offset_t);
      store_matrix_16x16_to_global<!kOutputTrueTransposed>(
          c_frag[0], c_frag[1], c_frag[2], c_frag[3], output_t_b128,
          kOutputTrueTransposed ? num_input_rows : num_input_cols);
    }
  }
  if constexpr (kUpdateIdentityAmax) {
    unpack_max_of_packed_bf16(local_amax_reg, local_amax);
    local_amax = warp_reduce_max<kThreadsPerWarp>(local_amax);
    // broadcast the amax to all threads in a warp from lane 0
    constexpr int lane_zero = 0;
    local_amax = __shfl_sync(0xFFFFFFFF, local_amax, lane_zero);
    // atomic CAS to output memory.
    if (tid % kThreadsPerWarp == 0) {
      atomicMaxFloat(amax, local_amax);
    }
  }
  if constexpr (kUpdateTransposeAmax) {
    unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t);
    local_amax_t = warp_reduce_max<kThreadsPerWarp>(local_amax_t);
    // broadcast the amax to all threads in a warp from lane 0
    constexpr int lane_zero = 0;
    local_amax_t = __shfl_sync(0xFFFFFFFF, local_amax_t, lane_zero);
    // atomic CAS to output memory.
    if (tid % kThreadsPerWarp == 0) {
      atomicMaxFloat(amax_t, local_amax_t);
    }
  }
#else
  NVTE_DEVICE_ERROR("Kernel is only supported on SM 9.0+.");
#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
}

}  // namespace
void hadamard_transform(const Tensor &input_, Tensor &output_, uint16_t random_sign_mask,
                        uint16_t random_sign_mask_t, cudaStream_t stream) {
  NVTE_API_CALL(hadamard_transform);
  // Check tensors
  // NOTE (frsun): This is non-intuitive, we are writing the result of
  // transposed RHT to the output of rowwise.
  NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Input tensor must be BF16 tensor, but scaling mode is ",
             to_string(input_.scaling_mode), ".");
  NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
             "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
  NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor.");
  NVTE_CHECK(output_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Output tensor must be simple tensor, but scaling mode is ",
             to_string(output_.scaling_mode), ".");

  const SimpleTensor &input = input_.data;
  SimpleTensor output;
  SimpleTensor &output_t = output_.data;

  // Check requested outputs
  const bool return_identity = output.dptr != nullptr;
  const bool return_transposed = output_t.dptr != nullptr;
  if (!return_identity && !return_transposed) {
    // Nothing to do/ill-defined behavior.
    return;
  }

  checkCuDriverContext(stream);

  const size_t ndim = input.shape.size();
  const size_t row_length = input.shape[ndim - 1];
  size_t num_rows = 1;
  for (size_t i = 0; i < ndim - 1; ++i) {
    num_rows *= input.shape[i];
  }

  using IType = bf16;
  constexpr int kHadamardDimension = 16;
  NVTE_CHECK(row_length % kHadamardDimension == 0,
             "row_length must be divisible by hadamard_dimension.");
  NVTE_CHECK(num_rows % kHadamardDimension == 0,
             "num_rows must be divisible by hadamard_dimension");

  constexpr uint64_t kThreadBlockX = 4;
  // A value of 4 is used for Hopper; 8 is used for Blackwell for extra memory bandwidth.
  constexpr uint64_t kThreadBlockY = 4;
  uint64_t kNumWarpsPerSM = kThreadBlockX * kThreadBlockY;
  // The shared memory number of bytes required for **the whole threadblock**.
  size_t shmem_bytes = kHadamardDimension * kHadamardDimension * sizeof(IType) * kNumWarpsPerSM;
  dim3 block(kThreadsPerWarp, kThreadBlockX, kThreadBlockY);
  dim3 grid(DIVUP(row_length / kHadamardDimension, kThreadBlockX),
            DIVUP(num_rows / kHadamardDimension, kThreadBlockY));

  TRANSFORMER_ENGINE_SWITCH_CONDITION(
      return_transposed, kReturnTransposed,
      TRANSFORMER_ENGINE_SWITCH_CONDITION(
          return_identity, kReturnIdentity,
          auto kernel = HadamardTransformKernel<IType, kHadamardDimension, kReturnIdentity,
                                                kReturnTransposed, kReturnIdentity,
                                                kReturnTransposed, false, false, true>;
          cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes);
          kernel<<<grid, block, shmem_bytes, stream>>>(
              reinterpret_cast<const IType *>(input.dptr),
              reinterpret_cast<IType *>(output.dptr), reinterpret_cast<IType *>(output_t.dptr),
              random_sign_mask, random_sign_mask_t, num_rows, row_length, nullptr, nullptr,
              false);););

  NVTE_CHECK_CUDA(cudaGetLastError());
}
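
// Note added for readability (not part of the original commit): with the constants above,
// block = (kThreadsPerWarp, 4, 4) = 512 threads (16 warps), and each warp transforms one
// 16x16 tile. For example, a 4096 x 8192 BF16 input yields grid = (8192/16/4, 4096/16/4)
// = (128, 64) thread blocks and shmem_bytes = 16 * 16 * 2 * 16 = 8192 bytes of dynamic shared
// memory per block.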
// Kernel that will apply the 16x16 hadamard transform to the input and input.T, and then
// get the absolute max value of the result.
void hadamard_transform_amax(const Tensor &input_, Tensor &output_, uint16_t random_sign_mask,
                             uint16_t random_sign_mask_t, cudaStream_t stream) {
  NVTE_API_CALL(hadamard_transform_amax);
#if CUDA_VERSION >= 12080
  // Check input tensor
  NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Input tensor must be BF16 tensor, but scaling mode is ",
             to_string(input_.scaling_mode), ".");
  NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
             "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
  NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor.");
  const SimpleTensor &input = input_.data;

  // Check amax tensors
  SimpleTensor &output_pre_rht_amax = output_.amax;
  SimpleTensor output_identity_amax;
  SimpleTensor &output_transpose_amax = output_.columnwise_amax;

  // Check requested outputs
  const bool return_pre_rht_amax = output_pre_rht_amax.dptr != nullptr;
  const bool return_identity_amax = output_identity_amax.dptr != nullptr;
  const bool return_transposed_amax = output_transpose_amax.dptr != nullptr;
  if (!return_identity_amax && !return_transposed_amax && !return_pre_rht_amax) {
    // Nothing to do/ill-defined behavior.
    return;
  }

  // Zero out amaxes if needed
  ZeroAmaxKernel<<<1, 1, 0, stream>>>(reinterpret_cast<float *>(output_pre_rht_amax.dptr),
                                      reinterpret_cast<float *>(output_identity_amax.dptr),
                                      reinterpret_cast<float *>(output_transpose_amax.dptr));
  NVTE_CHECK_CUDA(cudaGetLastError());

  checkCuDriverContext(stream);

  using IType = bf16;
  const size_t ndim = input.shape.size();
  const size_t row_length = input.shape[ndim - 1];
  size_t num_rows = 1;
  for (size_t i = 0; i < ndim - 1; ++i) {
    num_rows *= input.shape[i];
  }

  constexpr int kHadamardDimension = 16;
  NVTE_CHECK(row_length % kHadamardDimension == 0,
             "row_length must be divisible by hadamard_dimension.");
  NVTE_CHECK(num_rows % kHadamardDimension == 0,
             "num_rows must be divisible by hadamard_dimension");

  constexpr uint64_t kChunkBlockXSmall = 128;
  constexpr uint64_t kChunkBlockYSmall = 128;
  constexpr uint64_t kBuffDimX = 64;
  constexpr uint64_t kBuffDimY = 64;

  alignas(64) CUtensorMap tensor_map_input{};
  create_2D_tensor_map(/*tensorMap=*/tensor_map_input,
                       /*tensor=*/input,
                       /*globalY=*/num_rows,
                       /*globalX=*/row_length,
                       /*shmemY=*/kBuffDimY,
                       /*shmemX=*/kBuffDimX,
                       /*stride_elems=*/row_length,
                       /*offset_elems=*/0,
                       /*type_num_bits=*/sizeof(IType) * 8,
                       /*swizzle=*/CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B);

  constexpr uint64_t kThreadBlockX = 4;
  constexpr uint64_t kThreadBlockY = 1;
  constexpr uint64_t kNumWarps = kThreadBlockX * kThreadBlockY;
  dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY);
  dim3 grid(DIVUP(row_length, kChunkBlockXSmall), DIVUP(num_rows, kChunkBlockYSmall));

  TRANSFORMER_ENGINE_SWITCH_CONDITION(
      return_transposed_amax, kReturnTransposedAmax,
      TRANSFORMER_ENGINE_SWITCH_CONDITION(
          return_identity_amax, kReturnIdentityAmax,
          TRANSFORMER_ENGINE_SWITCH_CONDITION(
              return_pre_rht_amax, kReturnPreRhtAmax,
              // *2 for ping-pong
              size_t in_sh_size = kBuffDimX * kBuffDimY * 2 * sizeof(IType);
              size_t mbar_size = sizeof(uint64_t) * (kChunkBlockXSmall / kBuffDimX) *
                                 (kChunkBlockYSmall / kBuffDimY);
              size_t shmem_bytes = in_sh_size + mbar_size + kNumWarps * sizeof(float) * 3;
              // Add padding in case shmem ptr is not aligned to 128 bytes.
              shmem_bytes = (shmem_bytes + 128);
              auto kernel = HadamardAmaxTmaKernel<IType, kHadamardDimension, kChunkBlockYSmall,
                                                  kChunkBlockXSmall, kBuffDimY, kBuffDimX,
                                                  kThreadBlockX * kThreadsPerWarp, kThreadBlockY,
                                                  kReturnPreRhtAmax, kReturnIdentityAmax,
                                                  kReturnTransposedAmax>;
              cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                   shmem_bytes);
              kernel<<<grid, block, shmem_bytes, stream>>>(
                  tensor_map_input, reinterpret_cast<float *>(output_pre_rht_amax.dptr),
                  reinterpret_cast<float *>(output_identity_amax.dptr),
                  reinterpret_cast<float *>(output_transpose_amax.dptr), random_sign_mask,
                  random_sign_mask_t, num_rows, row_length);)));

  NVTE_CHECK_CUDA(cudaGetLastError());
#else
  NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ",
             CUDA_VERSION);
#endif  // CUDA_VERSION >= 12080
}

}  // namespace transformer_engine
void nvte_hadamard_transform(const NVTETensor input, NVTETensor output, int random_sign_mask,
                             int random_sign_mask_t, cudaStream_t stream) {
  NVTE_API_CALL(nvte_hadamard_transform);
  using namespace transformer_engine;
  hadamard_transform(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output),
                     static_cast<uint16_t>(random_sign_mask),
                     static_cast<uint16_t>(random_sign_mask_t), stream);
}
void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output,
                                  int random_sign_mask, int random_sign_mask_t,
                                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_hadamard_transform_amax);
  using namespace transformer_engine;
  hadamard_transform_amax(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output),
                          static_cast<uint16_t>(random_sign_mask),
                          static_cast<uint16_t>(random_sign_mask_t), stream);
}
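
For orientation (not part of the commit itself), a minimal call sketch of the new C API declared in transformer_engine/hadamard_transform.h. It assumes the caller already holds BF16 NVTETensor handles prepared through the usual TransformerEngine tensor machinery; apply_rht and the handle names are hypothetical, and only nvte_hadamard_transform comes from this file:

#include <cuda_runtime.h>
#include <transformer_engine/hadamard_transform.h>

// Hypothetical wrapper, illustration only.
void apply_rht(NVTETensor bf16_input, NVTETensor bf16_output, cudaStream_t stream) {
  // A zero sign mask selects the plain 16x16 Hadamard matrix; nonzero masks randomize the
  // row/column signs of the transform (and of its transposed variant).
  const int random_sign_mask = 0;
  const int random_sign_mask_t = 0;
  nvte_hadamard_transform(bf16_input, bf16_output, random_sign_mask, random_sign_mask_t, stream);
}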
transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu (new file, 0 → 100644)
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <cutlass/arch/barrier.h>
#include <transformer_engine/hadamard_transform.h>
#include <cuda/barrier>
#include <cute/algorithm/gemm.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#include "common/common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "curanddx.hpp"
#include "cutlass/arch/barrier.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/builders/sm100_common.inl"
#include "cutlass/numeric_conversion.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/helper_cuda.hpp"
#include "cutlass/util/print_error.hpp"
// clang-format off
namespace transformer_engine {
namespace detail {
namespace {

// Define a cuRANDDx descriptor.
// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not
// specified, it will default to 10.
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used to
// do some internal checks, e.g., whether shared memory, if needed, is enough for the described
// problem; usually not applicable.
// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
                     curanddx::SM<800>() + curanddx::Thread());
using namespace cute;
using cute::Tensor;  // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor

// Calculate the global encode scale factor for a given global amax.
__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) {
  constexpr float kFP8E4M3Max = 448.0f;
  constexpr float kFP4E2M1Max = 6.0f;
  // If scale is infinity, return max value of float32
  float global_encode_scale = cutlass::minimum_with_nan_propagation<float>{}(
      kFP8E4M3Max * kFP4E2M1Max / global_amax, cutlass::platform::numeric_limits<float>::max());
  // If global amax is 0 or infinity, return 1
  return (global_amax == 0.f || global_encode_scale == 0.f) ? 1.f : global_encode_scale;
}
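
// --- Illustration added for readability; not part of the original commit. ---
// Worked example of the formula above: for global_amax = 100.0f the encode scale is
// 448 * 6 / 100 = 26.88, i.e. values are rescaled so the largest post-RHT magnitude maps to
// (FP8 E4M3 max) * (FP4 E2M1 max). A compile-time mirror (hypothetical, illustration only):
constexpr float kIllustrativeEncodeScaleAmax100 = 448.0f * 6.0f / 100.0f;
static_assert(kIllustrativeEncodeScaleAmax100 > 26.87f && kIllustrativeEncodeScaleAmax100 < 26.89f,
              "encode scale for amax = 100 is ~26.88");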
template <class ElementA, class ElementB, class ASmemLayout, class BSmemLayout>
struct SharedStorage {
  static constexpr int AccumulatorPipelineStageCount = 16;
  using AtomThrShapeMNK = cute::Shape<_1, _1, _1>;
  using AccumulatorPipeline =
      cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / 4, AtomThrShapeMNK>;
  using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage;

  static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{});
  using MainloopPipeline =
      cutlass::PipelineTmaUmmaAsync<MainloopPipelineStageCount, Shape<_1, _1, _1>,
                                    AtomThrShapeMNK>;
  using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage;

  alignas(16) AccumulatorPipelineStorage accumulator;
  alignas(16) MainloopPipelineStorage mainloop;
  alignas(16) cute::uint64_t tma_barrier[1];
  uint32_t tmem_base_ptr;

  struct TensorStorage : cute::aligned_struct<128, _1> {
    // cute::array_aligned<ElementA, cute::cosize_v<ASmemLayout>> smem_A;
    cute::array_aligned<ElementA, cute::cosize_v<ASmemLayout>> smem_A;
    cute::array_aligned<ElementB, cute::cosize_v<BSmemLayout>> smem_B;
  } tensors;
};
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 8> StochasticNumericConverterBase(
    cutlass::Array<float, 8> const &input, cutlass::Array<uint32_t, 2> const &rbits) {
  using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>;
  result_type output;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  auto output_ptr = reinterpret_cast<uint16_t *>(&output);
  asm volatile(
      "{\n"
      "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n"
      "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n"
      "}"
      : "=h"(output_ptr[0]), "=h"(output_ptr[1])
      : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]),
        "f"(input[5]), "f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1]));
#else
  NVTE_DEVICE_ERROR(
      "FP4 cvt PTX instructions are architecture-specific. "
      "Try recompiling with sm_XXXa instead of sm_XXX.");
#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  return output;
}
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 16> StochasticNumericConverter(
    cutlass::Array<float, 16> const &input, cutlass::Array<uint32_t, 4> const *rbits) {
  using result_type = cutlass::Array<cutlass::float_e2m1_t, 16>;
  result_type output;
  cutlass::Array<cutlass::float_e2m1_t, 8> *result_ptr =
      reinterpret_cast<cutlass::Array<cutlass::float_e2m1_t, 8> *>(&output);
  cutlass::Array<float, 8> const *source_ptr =
      reinterpret_cast<cutlass::Array<float, 8> const *>(&input);
  cutlass::Array<uint32_t, 2> const *rbits_ptr =
      reinterpret_cast<cutlass::Array<uint32_t, 2> const *>(rbits);
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < 2; i++) {
    result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]);
  }
  return output;
}
template <class MShape, class NShape, class KShape, class ClusterTileShape, class TA,
          class AStride, class ASmemLayout, class TmaLoadA, class TB, class BStride,
          class BSmemLayout, class TmaLoadB, class TC, class CStride, class CSmemLayout,
          class TSFC, class TiledMMA, bool kEnableStochasticRounding = false>
__global__ static void rht_gemm_device(MShape M, NShape N, KShape K,
                                       ClusterTileShape cluster_tile, TA const *A, AStride dA,
                                       ASmemLayout sAlayout,
                                       CUTE_GRID_CONSTANT TmaLoadA const tma_load_a,
                                       TB const *B, BStride dB, BSmemLayout sBlayout,
                                       CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, TC *C,
                                       CStride dC, CSmemLayout, TSFC *SFC, TiledMMA mma,
                                       float const *global_amax, const size_t *rng_state) {
  using namespace cute;
  using X = Underscore;

  // static constexpr bool kApplyStochasticRounding = true;
  using ElementAccumulator = float;
  static constexpr int K_PIPE_MAX = size<3>(ASmemLayout{});
  using AtomThrShapeMNK =
      Shape<decltype(shape<0>(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>;
  static constexpr uint32_t kTmaTransactionBytes = cutlass::bits_to_bytes(
      size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v<TA>);
  static constexpr int kTmaRhtTensorTransactionBytes =
      cutlass::bits_to_bytes(16 * 16 * cute::sizeof_bits_v<TB>);
  static constexpr int AccumulatorPipelineStageCount = 16;
  static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{});
  using MainloopPipeline =
      cutlass::PipelineTmaUmmaAsync<MainloopPipelineStageCount, Shape<_1, _1, _1>,
                                    AtomThrShapeMNK>;
  using MainloopPipelineState = typename MainloopPipeline::PipelineState;
  using TmemAllocator = cute::TMEM::Allocator1Sm;
  static constexpr int VectorSize = 16;

  const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
  const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;

  // Preconditions
  CUTE_STATIC_ASSERT(is_static<ASmemLayout>::value);
  CUTE_STATIC_ASSERT(is_static<BSmemLayout>::value);
  CUTE_STATIC_ASSERT(is_static<CSmemLayout>::value);

  // Represent the full tensors
  Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, N));
  Tensor mB = tma_load_b.get_tma_tensor(make_shape(16, 16));
  Tensor mC = make_tensor(cute::subbyte_iterator<TC>(C), make_shape(M, N), dC);  // (M,N)

  auto sfc_shape = make_shape(M, make_shape(make_shape(Int<16>{}, _4{}), N / 64));
  auto sfc_stride = make_stride(N / 16, make_stride(make_stride(_0{}, _1{}), _4{}));
  auto sfc_layout = make_layout(sfc_shape, sfc_stride);
  Tensor mSFC = make_tensor(make_gmem_ptr(SFC), sfc_layout);

  auto cluster_shape = Shape<_1, _1, _1>{};
  // Get the appropriate blocks for this Cluster
  dim3 cluster_coord_in_grid = cluster_id_in_grid();
  // Total number of k-tiles
  const int K_TILE_MAX = min(N, K) / 64;
  uint32_t tiles_in_m = (M + size<0>(cluster_tile) - 1) / size<0>(cluster_tile);
  uint32_t tiles_in_n = (N + 64 - 1) / 64;
  uint32_t linear_tile_idx = blockIdx.x;
  uint32_t tile_idx_m = linear_tile_idx % tiles_in_m;
  uint32_t tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;

  auto mainloop_tiler = Shape<_128, _16, _64>{};
  auto epilogue_tiler = Shape<_128, _64, _64>{};
  Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{});
  Tensor gB_nk =
      local_tile(mB, cluster_tile, make_coord(_, _, _), Step<X, _1, _1>{});  // (BLK_N,BLK_K,k)
  Tensor gC_mn =
      local_tile(mC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{});  // (BLK_M,BLK_N)
  Tensor gSFC_mn =
      local_tile(mSFC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{});  // (BLK_M,BLK_N)

  // Allocate SMEM
  extern __shared__ char shared_memory[];
  using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
  SharedStorage &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);

  Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()),
                            sAlayout);  // (MMA,MMA_M,MMA_N,PIPE)
  Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()),
                            sBlayout);  // (MMA,MMA_N,MMA_K,PIPE)

  //
  // MMA: Define C accumulators and A/B partitioning
  //
  int block_rank_in_cluster = cute::block_rank_in_cluster();
  ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster);  // blk idx
  Tensor tCgB = thr_mma.partition_B(gB_nk);               // (MMA,MMA_N,MMA_K,k)

  auto mma_epilogue = make_tiled_mma(
      SM100_MMA_F16BF16_SS<TA, TB, ElementAccumulator, 128, 64, UMMA::Major::MN,
                           UMMA::Major::MN>{},
      Layout<Shape<_1, _1>>{});
  ThrMMA thr_mma_epilogue = mma_epilogue.get_slice(block_rank_in_cluster);
  using TiledMmaEpilogue = decltype(mma_epilogue);

  Tensor tCgA = thr_mma.partition_A(gA_mk);

  // Allocate "fragments" -- these are actually umma smem descriptors
  Tensor tCrA = thr_mma.make_fragment_A(tCsA);  // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCrB = thr_mma.make_fragment_B(tCsB);  // (MMA,MMA_M,MMA_K,PIPE)

  auto acc_shape_mma = partition_shape_C(TiledMMA{}, take<0, 2>(ClusterTileShape{}));
  auto acc_shape_epilogue = partition_shape_C(TiledMmaEpilogue{}, take<0, 2>(epilogue_tiler));
  auto bulk_tmem_mma =
      TiledMMA::make_fragment_C(append(acc_shape_mma, Int<AccumulatorPipelineStageCount>{}));
  auto bulk_tmem_epilogue = TiledMmaEpilogue::make_fragment_C(
      append(acc_shape_epilogue, Int<AccumulatorPipelineStageCount / 4>{}));

  TmemAllocator tmem_allocator{};
  cutlass::arch::NamedBarrier tmem_allocation_result_barrier(
      32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier);

  Layout cta_layout_mnk = make_layout(cluster_shape);
  Layout cta_layout_vmnk =
      tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{}));
  auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster);

  auto [tAgA, tAsA] =
      tma_partition(tma_load_a, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)),
                    group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA));
  auto [tBgB, tBsB] =
      tma_partition(tma_load_b, get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)),
                    group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB));

  uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
  uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk);

  int warp_idx = cutlass::canonical_warp_idx_sync();
  bool is_mma_warp = (warp_idx == 0);
  bool is_dma_warp = (warp_idx == 1);
  bool is_epilogue_warp = (warp_idx >= 4 && warp_idx <= 7);
  if (is_epilogue_warp && elect_one_sync()) {
    cute::prefetch(raw_pointer_cast(global_amax));
  }

  typename MainloopPipeline::Params mainloop_pipeline_params;
  if (is_dma_warp) {
    mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
  }
  if (is_mma_warp) {
    mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
  }
  mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp;
  mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes;
  mainloop_pipeline_params.initializing_warp = 0;
  MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params,
                                     cluster_shape,
                                     cute::true_type{},   // Perform barrier init
                                     cute::true_type{});  // Delay mask calculation

  MainloopPipelineState mainloop_pipe_consumer_state;
  MainloopPipelineState mainloop_pipe_producer_state =
      cutlass::make_producer_start_state<MainloopPipeline>();

  using AccumulatorPipeline =
      cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / 4, AtomThrShapeMNK>;
  using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState;
  AccumulatorPipelineState accumulator_pipe_consumer_state;
  AccumulatorPipelineState accumulator_pipe_producer_state =
      cutlass::make_producer_start_state<AccumulatorPipeline>();
  typename AccumulatorPipeline::Params accumulator_pipeline_params;
  if (is_mma_warp) {
    accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer;
  }
  if (is_epilogue_warp) {
    accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer;
  }
  // Only one producer thread arrives on this barrier.
  accumulator_pipeline_params.producer_arv_count = 1;
  accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * 128;
  accumulator_pipeline_params.initializing_warp = 1;
  AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator,
                                           accumulator_pipeline_params, cluster_shape,
                                           cute::true_type{},   // Perform barrier init
                                           cute::true_type{});  // Delay mask calculation

  if (warp_idx == 2 && elect_one_sync()) {
    cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1);
  }
  __syncthreads();

  using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x;

  if (is_dma_warp) {
    if (elect_one_sync()) {
      cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0],
                                          kTmaRhtTensorTransactionBytes);
      copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0),
           tBsB(_, 0));
    }
    cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/);
    do {
      bool is_first_wave = linear_tile_idx == blockIdx.x;
      uint32_t skip_wait = is_first_wave;
      auto tAgA_mk = tAgA(_, tile_idx_m, _);
      int k_tile = 0;
      auto barrier_token =
          mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait);
      CUTE_NO_UNROLL
      while (k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n) {
        int k_tile_idx_n = tile_idx_n + k_tile;
        ++k_tile;
        skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount);
        mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token);
        using BarrierType = typename MainloopPipeline::ProducerBarrierType;
        BarrierType *tma_barrier =
            mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state);
        int write_stage = mainloop_pipe_producer_state.index();
        ++mainloop_pipe_producer_state;
        barrier_token =
            mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait);
        if (cute::elect_one_sync()) {
          copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n),
               tAsA(_, write_stage));
        }
      }
      linear_tile_idx += gridDim.x;
      tile_idx_m = linear_tile_idx % tiles_in_m;
      tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
    } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
    mainloop_pipeline.producer_tail(mainloop_pipe_producer_state);
  } else if (is_mma_warp) {
    mma.accumulate_ = UMMA::ScaleOut::Zero;
    tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns,
                            &shared_storage.tmem_base_ptr);
    __syncwarp();
    tmem_allocation_result_barrier.arrive();
    uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
    bulk_tmem_mma.data() = tmem_base_ptr;
    do {
      uint32_t skip_wait = K_TILE_MAX <= 0;
      auto barrier_token =
          mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
      CUTE_NO_UNROLL
      for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n;) {
        mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
        int read_stage = mainloop_pipe_consumer_state.index();
        auto tCrA_mk = tCrA(_, _, _, read_stage);
        auto tCrB_nk = tCrB(_, _, 0, 0);
        CUTE_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA) / 4; ++k_block) {
          accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
          CUTE_UNROLL
          for (int i = 0; i < 4; i++) {
            auto accumulators =
                bulk_tmem_mma(_, _, _, accumulator_pipe_producer_state.index() * 4 + i);
            gemm(mma, tCrA_mk(_, _, k_block * 4 + i), tCrB_nk, accumulators);
          }
          accumulator_pipeline.producer_commit(accumulator_pipe_producer_state);
          ++accumulator_pipe_producer_state;
        }
        auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;
        ++mainloop_pipe_consumer_state;
        ++k_tile;
        skip_wait = k_tile >= K_TILE_MAX;
        barrier_token =
            mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
        mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
      }
      linear_tile_idx += gridDim.x;
      tile_idx_m = linear_tile_idx % tiles_in_m;
      tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
    } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
    tmem_allocator.release_allocation_lock();
    accumulator_pipeline.producer_tail(accumulator_pipe_producer_state);
    tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns);
  } else if (is_epilogue_warp) {
    const float global_amax_val = *global_amax;
    static constexpr int FragmentSize = 256 / sizeof_bits_v<TC>;
    tmem_allocation_result_barrier.arrive_and_wait();
    uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
    bulk_tmem_epilogue.data() = tmem_base_ptr;
    int thread_idx = threadIdx.x % 128;
    Tensor tCgC = thr_mma_epilogue.partition_C(gC_mn);  // (MMA,MMA_M,MMA_N)
    auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{}));
    auto tiled_r2g =
        make_tiled_copy_D(Copy_Atom<SM100_STORE_256bit_CACHE_NOALLOCATION, TC>{}, tiled_t2r);
    auto thr_t2r = tiled_t2r.get_slice(thread_idx);
    auto thr_r2g = tiled_r2g.get_slice(thread_idx);

    // NVFP4 non-E8 recipe constants and global scales
    static constexpr float fp4_max = 6.0f;
    const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val);
    const float global_decode_scale = 1.0f / global_encode_scale;
    auto sfd_converter = cutlass::NumericConverter<TSFC, float>{};
    do {
      for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) {
        Tensor tCgC_mn = tCgC(_, _, _, tile_idx_m, tile_idx_n + k_tile);
        Tensor tCgSFC_mn = gSFC_mn(_, _, tile_idx_m, tile_idx_n + k_tile);
        accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state);
        auto tCtC = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index());
        Tensor tDtC = thr_t2r.partition_S(tCtC);     // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
        Tensor tDgC = thr_t2r.partition_D(tCgC_mn);  // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
        Tensor tTR_rAcc =
            make_tensor<ElementAccumulator>(shape(tDgC));  // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
        Tensor tDrC = make_tensor<TC>(shape(tDgC));
        Tensor tTR_rAcc_frag =
            recast<cutlass::Array<ElementAccumulator, FragmentSize>>(coalesce(tTR_rAcc));
        Tensor tDrC_frag = recast<cutlass::Array<TC, FragmentSize>>(coalesce(tDrC));
        Tensor src = thr_r2g.retile_S(tDrC);
        Tensor dst = thr_r2g.retile_D(tDgC);
        Tensor tCgSFC = make_tensor(
            tCgSFC_mn.data(),
            make_layout(make_shape(shape(tCgSFC_mn), Int<1>{}, Int<1>{}),
                        make_stride(stride(tCgSFC_mn), Int<0>{}, Int<0>{})));
        Tensor tDgSFC = filter(thr_t2r.partition_D(tCgSFC));
        Tensor tDrSFC = make_tensor<TSFC>(shape(tDgSFC));
        static constexpr int NumVecs = size(tDgC) / VectorSize;
        Tensor tC_rRowSFD_frg = recast<cutlass::Array<TSFC, NumVecs>>(tDrSFC);
        cutlass::maximum_absolute_value_reduction<
            cutlass::Array<ElementAccumulator, VectorSize>, true>
            amax_reduction;
        cutlass::Array<ElementAccumulator, NumVecs> vec_maxs;
        cutlass::Array<ElementAccumulator, NumVecs> pvscales;
        // TMEM_LOAD
        copy(tiled_t2r, tDtC, tTR_rAcc);
        cutlass::arch::fence_view_async_tmem_load();
        accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state);
        ++accumulator_pipe_consumer_state
;
// Cast data from FP32 to BF16 to FP32.
auto
convert_accum_to_bf16
=
cutlass
::
NumericArrayConverter
<
cutlass
::
bfloat16_t
,
ElementAccumulator
,
FragmentSize
>
{};
auto
convert_bf16_to_accum
=
cutlass
::
NumericArrayConverter
<
ElementAccumulator
,
cutlass
::
bfloat16_t
,
FragmentSize
>
{};
tTR_rAcc_frag
(
_0
{})
=
convert_bf16_to_accum
(
convert_accum_to_bf16
(
tTR_rAcc_frag
(
_0
{})));
auto
compute_frgs
=
reinterpret_cast
<
cutlass
::
Array
<
ElementAccumulator
,
VectorSize
>
*>
(
tTR_rAcc_frag
.
data
());
auto
output_frgs
=
reinterpret_cast
<
cutlass
::
Array
<
TC
,
VectorSize
>
*>
(
tDrC_frag
.
data
());
CUTLASS_PRAGMA_UNROLL
for
(
int
v
=
0
;
v
<
NumVecs
;
v
++
)
{
vec_maxs
[
v
]
=
amax_reduction
(
ElementAccumulator
(
0
),
compute_frgs
[
v
]);
}
pvscales
=
cutlass
::
divides
<
cutlass
::
Array
<
ElementAccumulator
,
NumVecs
>>
{}(
vec_maxs
,
fp4_max
);
pvscales
=
cutlass
::
multiplies
<
cutlass
::
Array
<
ElementAccumulator
,
NumVecs
>>
{}(
pvscales
,
global_encode_scale
);
auto
pvscales_cvted
=
cutlass
::
NumericArrayConverter
<
TSFC
,
ElementAccumulator
,
NumVecs
>
{}(
pvscales
);
tC_rRowSFD_frg
(
_0
{})
=
pvscales_cvted
;
auto
qpvscale_ups
=
cutlass
::
NumericArrayConverter
<
ElementAccumulator
,
TSFC
,
NumVecs
>
{}(
tC_rRowSFD_frg
(
_0
{}));
auto
qpvscale_scaled
=
cutlass
::
multiplies
<
cutlass
::
Array
<
ElementAccumulator
,
NumVecs
>>
{}(
qpvscale_ups
,
global_decode_scale
);
auto
acc_scales
=
cutlass
::
divides
<
cutlass
::
Array
<
ElementAccumulator
,
NumVecs
>>
{}(
1.0
,
qpvscale_scaled
);
// Initialize RNG for tile
const
size_t
rng_sequence
=
thread_idx
+
k_tile
*
256
+
linear_tile_idx
*
K_TILE_MAX
*
256
;
RNG
rng
(
rng_seed
,
rng_sequence
,
rng_offset
);
curanddx
::
uniform_bits
dist
;
uint4
random_uint4
=
uint4
{
0
,
0
,
0
,
0
};
CUTLASS_PRAGMA_UNROLL
for
(
int
v
=
0
;
v
<
NumVecs
;
v
++
)
{
auto
acc_scale
=
cutlass
::
minimum_with_nan_propagation
<
ElementAccumulator
>
{}(
acc_scales
[
v
],
cutlass
::
platform
::
numeric_limits
<
ElementAccumulator
>::
max
());
// auto acc_scale = acc_scales[v];
if
constexpr
(
kEnableStochasticRounding
)
{
random_uint4
=
dist
.
generate4
(
rng
);
output_frgs
[
v
]
=
StochasticNumericConverter
(
cutlass
::
multiplies
<
cutlass
::
Array
<
ElementAccumulator
,
VectorSize
>>
{}(
compute_frgs
[
v
],
acc_scale
),
reinterpret_cast
<
cutlass
::
Array
<
uint32_t
,
4
>*>
(
&
random_uint4
));
}
else
{
output_frgs
[
v
]
=
cutlass
::
NumericArrayConverter
<
TC
,
ElementAccumulator
,
VectorSize
>
{}(
cutlass
::
multiplies
<
cutlass
::
Array
<
ElementAccumulator
,
VectorSize
>>
{}(
compute_frgs
[
v
],
acc_scale
));
}
}
copy
(
tiled_r2g
,
src
,
dst
);
copy
(
AutoVectorizingCopyWithAssumedAlignment
<
128
>
{},
tDrSFC
,
tDgSFC
);
}
linear_tile_idx
+=
gridDim
.
x
;
tile_idx_m
=
linear_tile_idx
%
tiles_in_m
;
tile_idx_n
=
(
linear_tile_idx
/
tiles_in_m
)
*
K_TILE_MAX
;
}
while
(
tile_idx_m
<
tiles_in_m
&&
tile_idx_n
<
tiles_in_n
);
}
}
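For readers following the epilogue above, here is a minimal host-side sketch (not part of this commit) of the two-level NVFP4 scaling arithmetic it performs for every 16-element block. The helper names, the scalar loop, and the exact formula inside global_encode_scale_fp4 are assumptions for illustration; the kernel itself works on cutlass::Array fragments, rounds the per-block scale to E4M3 before inverting it, clamps the resulting multiplier, and optionally applies stochastic rounding during the FP4 cast.

// Minimal sketch, assuming a conventional NVFP4 global encode scale of
// fp8_max * fp4_max / amax; the kernel calls ComputeGlobalEncodeScaleFP4,
// whose real definition lives elsewhere in the library.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

float global_encode_scale_fp4(float global_amax) {
  constexpr float fp4_max = 6.0f;    // max magnitude of e2m1 values
  constexpr float fp8_max = 448.0f;  // max magnitude of the e4m3 scale type
  if (global_amax == 0.0f) return 1.0f;
  return fp8_max * fp4_max / global_amax;
}

// Quantize-prepare one row: per 16-element block compute the amax, derive the
// per-block scale factor (amax / 6 * encode_scale, stored as E4M3 in the kernel),
// and multiply the block by the reciprocal of its decode scale before the FP4 cast.
void quantize_blocks_nvfp4(const float *x, std::size_t n, float global_amax,
                           std::vector<float> *scaled, std::vector<float> *block_scales) {
  constexpr std::size_t kVectorSize = 16;
  constexpr float fp4_max = 6.0f;
  const float enc = global_encode_scale_fp4(global_amax);
  const float dec = 1.0f / enc;
  for (std::size_t v = 0; v < n; v += kVectorSize) {
    const std::size_t end = std::min(v + kVectorSize, n);
    float amax = 0.0f;
    for (std::size_t i = v; i < end; ++i) amax = std::max(amax, std::fabs(x[i]));
    const float sf = (amax / fp4_max) * enc;  // per-block scale factor (SFC entry)
    block_scales->push_back(sf);
    const float acc_scale = (sf > 0.0f) ? 1.0f / (sf * dec) : 0.0f;  // multiplier before FP4 cast
    for (std::size_t i = v; i < end; ++i) scaled->push_back(x[i] * acc_scale);
  }
}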
// this function computes RHT-GEMM for
// A: m x n: col-major
// B: 16 x 16: row-major
// C: m x n: row-major
// SFC: m x (n/16): row-major
template <typename TA, typename TB, typename TC, typename TSFC,
          bool kEnableStochasticRounding = false>
void rht_gemm_ntt_w_sfc(int m, int n, TA const *A, TB const *B, TC *C, TSFC *SFC,
                        float const *global_amax, const size_t *rng_state, uint32_t sm_count,
                        cudaStream_t stream, int k_tile_size = 2048) {
  using namespace cute;

  // Define shapes (dynamic)
  auto M = static_cast<int>(m);
  auto N = static_cast<int>(n);

  // Define strides (mixed)
  auto dA = make_stride(Int<1>{}, m);   // (dM,dK)
  auto dB = make_stride(Int<1>{}, 16);  // (dN,dK)
  auto dC = make_stride(n, Int<1>{});   // (dM,dN)

  auto cga_shape = Shape<_1, _1, _1>{};
  auto cga_tile_shape = Shape<_128, _16, _16>{};
  auto cluster_tile_mainloop = Shape<_128, _16, _64>{};

  // Construct the MMA
  auto mma = make_tiled_mma(
      SM100_MMA_F16BF16_SS<TA, TB, float, 128, 16, UMMA::Major::MN, UMMA::Major::MN>{},
      Layout<Shape<_1, _1>>{});

  // MMA in CGA Layout XXX: Need to generalize synchro? {$nv-release-never}
  // Assert that the TiledMMA uses all CTAs in the CGA.
  CUTE_STATIC_ASSERT_V(size(cga_shape) == size(mma));
  CUTE_STATIC_ASSERT_V(evenly_divides(cga_tile_shape, tile_shape(mma)));

  // Determine the A and B shapes
  auto mma_shape_B =
      partition_shape_B(mma, make_shape(size<1>(cga_tile_shape), size<2>(cga_tile_shape)));
  using TiledMma = decltype(mma);
  using AtomThrID = typename TiledMma::AtomThrID;
  using SmemShape_M = decltype(shape_div(
      shape<0>(cga_tile_shape),
      shape_div(shape<0>(cga_tile_shape), size<0>(cga_tile_shape) / size(AtomThrID{}))));
  using SmemShape_N = decltype(shape_div(
      shape<1>(cga_tile_shape),
      shape_div(shape<1>(cga_tile_shape), size<1>(cga_tile_shape) / size(AtomThrID{}))));
  using SmemShape_K = decltype(cute::get<2>(cga_tile_shape));
  using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
                                   cute::UMMA::Major::MN, TB, SmemShape_N, SmemShape_K>());
  auto mma_shape_A = partition_shape_A(
      mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop)));
  using SmemShape_M_A = decltype(shape_div(
      shape<0>(cluster_tile_mainloop),
      shape_div(shape<0>(cluster_tile_mainloop),
                size<0>(cluster_tile_mainloop) / size(AtomThrID{}))));
  using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop));
  using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
                                   cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>());

  // Define the smem layouts (static)
  // Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory
  constexpr int kBlackwellSmemSize = 232448;  // 232KB in bytes
  constexpr int kBytesPerStage =
      cute::size(mma_shape_A) * sizeof(TA) + cute::size(mma_shape_B) * sizeof(TB);
  constexpr int kReservedBytes = 256;  // Reserve for barriers and other uses
  constexpr int kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage;
  auto sP = Int<kMaxStages>{};  // SMEM pipelines

  auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP));  // (MMA,MMA_M,MMA_K,PIPE)
  auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, append(mma_shape_B, sP));  // (MMA,MMA_N,MMA_K,PIPE)
  auto sC = Layout<_1>{};  // XXX Dummy

  // Create GMEM tensors
  Tensor tensorA = make_tensor(A, make_layout(make_shape(M, N), dA));    // (M,N)
  Tensor tensorB = make_tensor(B, make_layout(make_shape(16, 16), dB));  // (16,16)

  // Create the TiledCopy
  auto tma_load_a =
      make_tma_copy_A_sm100(SM90_TMA_LOAD{}, tensorA, sA(_, _, _, 0), cluster_tile_mainloop, mma);
  auto tma_load_b =
      make_tma_copy_B_sm100(SM90_TMA_LOAD{}, tensorB, sB(_, _, _, 0), cga_tile_shape, mma);

  // Assert checks on tile sizes -- no predication
  NVTE_CHECK(M % size<0>(cga_tile_shape) == 0, "Inner dimension must be divisible by ",
             static_cast<size_t>(size<0>(cga_tile_shape)), " but got ", M, ".");
  NVTE_CHECK(N % (4 * size<1>(cga_tile_shape)) == 0, "Outer dimension must be divisible by ",
             4 * static_cast<size_t>(size<1>(cga_tile_shape)), " but got ", N, ".");

  uint32_t tiles =
      size(ceil_div(M, get<0>(cga_tile_shape))) * size(ceil_div(N, k_tile_size));
  tiles = (tiles < sm_count) ? tiles : sm_count;

  dim3 dimBlock(256);
  dim3 dimCluster(size<0>(cga_shape), size<1>(cga_shape), size<2>(cga_shape));
  dim3 dimGrid(tiles, 1, 1);
  int smem_size = sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>);

  auto *kernel_ptr =
      &rht_gemm_device<decltype(M), decltype(N), decltype(k_tile_size), decltype(cga_tile_shape),
                       TA, decltype(dA), decltype(sA), decltype(tma_load_a), TB, decltype(dB),
                       decltype(sB), decltype(tma_load_b), TC, decltype(dC), decltype(sC), TSFC,
                       decltype(mma), kEnableStochasticRounding>;

  bool status =
      cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  if (status != cudaSuccess) {
    std::cerr << "Error: Failed to set Shared Memory size." << std::endl;
    return;
  }

  (*kernel_ptr)<<<dimGrid, dimBlock, smem_size, stream>>>(
      M, N, k_tile_size, cga_tile_shape, A, dA, sA, tma_load_a, B, dB, sB, tma_load_b, C, dC, sC,
      SFC, mma, global_amax, rng_state);
}
// this function is used to wrap the rht_gemm_ntt_w_sfc function
// to transpose the input tensor A
template <typename TA, typename TB, typename TC, typename TSFC,
          bool kEnableStochasticRounding = false>
void rht_gemm_ttt_wrapper(int m, int n, TA const *A, TB const *B, TC *C, TSFC *SFC,
                          float const *global_amax, const size_t *rng_state, uint32_t sm_count,
                          cudaStream_t stream, int k_tile_size = 1024) {
  // In addition to transposing the input tensor A, we also reshape m, n so as to
  // utilize as many SMs as possible while keeping a relatively large contiguous dimension.
  // After swapping m, n for transpose purposes, the input / output tensor shapes
  // for RHT-GEMM are:
  //   A: n x m: col-major
  //   B: 16 x 16: row-major
  //   C: n x m: row-major
  //   SFC: n x (m/16): row-major
  rht_gemm_ntt_w_sfc<TA, TB, TC, TSFC, kEnableStochasticRounding>(
      n, m, A, B, C, SFC, global_amax, rng_state, sm_count, stream, k_tile_size);
}
}  // namespace
}  // namespace detail
// clang-format on

void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &output_,
                                               const Tensor &hadamard_matrix_,
                                               QuantizationConfig quant_config,
                                               cudaStream_t stream) {
  NVTE_API_CALL(hadamard_transform_cast_fusion_columnwise);

  // Check input and output tensors
  NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Input tensor must be BF16 tensor, but scaling mode is ",
             to_string(input_.scaling_mode), ".");
  NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
             "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
  NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor.");
  const SimpleTensor &input = input_.data;
  SimpleTensor &global_amax = output_.amax;
  SimpleTensor &output_t = output_.data;
  SimpleTensor &scale_inv_t = output_.scale_inv;

  // Stochastic rounding config
  const bool use_stochastic_rounding = quant_config.stochastic_rounding;
  const size_t *rng_state = nullptr;
  if (quant_config.rng_state != nullptr) {
    Tensor &rng_state_tensor = *convertNVTETensor(quant_config.rng_state);
    NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64,
               "RNG state should contain 2 64-bit values.");
    NVTE_CHECK(rng_state_tensor.data.shape == std::vector<size_t>{2},
               "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape);
    rng_state = reinterpret_cast<const size_t *>(rng_state_tensor.data.dptr);
  }

  // Template arguments
  using TA = cute::bfloat16_t;
  using TB = cute::bfloat16_t;
  using TC = cutlass::float_e2m1_t;
  using TSFC = cutlass::float_ue4m3_t;

  checkCuDriverContext(stream);

  // Check Hadamard matrix
  constexpr int kHadamardDimension = 16;
  NVTE_CHECK(hadamard_matrix_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Hadamard matrix must be BF16 tensor, but scaling mode is ",
             to_string(hadamard_matrix_.scaling_mode), ".");
  NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16,
             "Hadamard matrix must be BF16 tensor, but dtype is ",
             to_string(hadamard_matrix_.dtype()), ".");
  const SimpleTensor &hadamard_matrix = hadamard_matrix_.data;
  NVTE_CHECK(
      (hadamard_matrix_.shape() == std::vector<size_t>{kHadamardDimension, kHadamardDimension}),
      "Hadamard matrix must have shape=",
      std::vector<size_t>{kHadamardDimension, kHadamardDimension},
      ", but got shape=", hadamard_matrix_.shape(), ".");
  const size_t hadamard_dimension = hadamard_matrix.shape[0];

  const size_t ndim = input.shape.size();
  const size_t n = input.shape[ndim - 1];
  size_t m = 1;
  for (size_t i = 0; i < ndim - 1; ++i) {
    m *= input.shape[i];
  }
  auto sm_count = transformer_engine::cuda::sm_count();

  NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension.");
  NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension");

  int k_tile_size = 1024;
  if (m == 8192 && n == 5120) {
    k_tile_size = 512;
  } else if (m == 8192 && n == 10240) {
    k_tile_size = 1024;
  } else if (m == 8192 && n == 2560) {
    k_tile_size = 1280;
  } else if (m == 8192 && n == 11328) {
    k_tile_size = 1024;
  } else if (m == 8192 && n == 512) {
    k_tile_size = 256;
  } else if (m == 8192 && n == 3584) {
    k_tile_size = 512;
  } else if (m == 11328 && n == 8192) {
    k_tile_size = 1024;
  } else if (m == 5120 && n == 8192) {
    k_tile_size = 512;
  } else if (m == 10240 && n == 8192) {
    k_tile_size = 1024;
  } else if (m == 2560 && n == 8192) {
    k_tile_size = 1280;
  } else if (m == 512 && n == 8192) {
    k_tile_size = 256;
  } else if (m == 3584 && n == 8192) {
    k_tile_size = 512;
  } else if (m < 1024 || n < 1024) {
    k_tile_size = 512;
  }

  TRANSFORMER_ENGINE_SWITCH_CONDITION(
      use_stochastic_rounding, kUseStochasticRounding,
      detail::rht_gemm_ttt_wrapper<TA, TB, TC, TSFC, kUseStochasticRounding>(
          /*m=*/m,
          /*n=*/n,
          /*A=*/reinterpret_cast<TA const *>(input.dptr),
          /*B=*/reinterpret_cast<TB const *>(hadamard_matrix.dptr),
          /*C=*/reinterpret_cast<TC *>(output_t.dptr),
          /*SFC=*/reinterpret_cast<TSFC *>(scale_inv_t.dptr),
          /*global_amax=*/reinterpret_cast<float const *>(global_amax.dptr),
          /*rng_state=*/rng_state,
          /*sm_count=*/sm_count,
          /*stream=*/stream,
          /*k_tile_size=*/k_tile_size););
}
}  // namespace transformer_engine

void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output,
                                                    const NVTETensor hadamard_matrix,
                                                    const NVTEQuantizationConfig quant_config,
                                                    cudaStream_t stream) {
  NVTE_API_CALL(nvte_hadamard_transform_cast_fusion_columnwise);
  using namespace transformer_engine;
  QuantizationConfig quant_config_cpp;
  if (quant_config != nullptr) {
    quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
  }
  hadamard_transform_cast_fusion_columnwise(*convertNVTETensorCheck(input),
                                            *convertNVTETensorCheck(output),
                                            *convertNVTETensorCheck(hadamard_matrix),
                                            quant_config_cpp, stream);
}
transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h
...
@@ -67,6 +67,11 @@ class CommOverlapCore {
   std::vector<cudaStream_t> _stream_compute;
   cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event;

+ private:
+  void initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size,
+                  int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin,
+                  bool use_ce, bool atomic_gemm);
+
  public:
   CommOverlapCore() {}  // dummy constructor for exposing type to Python
...
@@ -78,17 +83,26 @@ class CommOverlapCore {
   virtual ~CommOverlapCore();

+  void *get_ubuf_dptr() { return _ubuf.dptr(); }
+
   void set_ubuf_scale_inv(float *scale_inv) {
     _ubuf_scale_inv = scale_inv;
     _ubuf_scale_inv_initialized = true;
   }

+  virtual void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk,
+                                bool rowwise = true) {
+    NVTE_ERROR("Operation is not implemented.");
+  }
+
+  TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset,
+                                 const std::vector<size_t> &shape);
+
+  TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset,
+                                      const std::vector<size_t> &shape);
+
   int get_tp_size() { return _tp_size; }

   bool is_atomic_gemm() { return _atomic_gemm; }

   bool is_p2p_overlap() { return _is_p2p; }
...
@@ -150,6 +164,10 @@ class CommOverlapBase : public CommOverlapCore {
   cudaStream_t _stream_comm;
   cudaEvent_t _start_d2dcopy;

+ private:
+  void initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
+                  bool rs_overlap_first_gemm);
+
  public:
   CommOverlapBase() {}  // dummy constructor for exposing type to Python
...
@@ -228,6 +246,10 @@ class CommOverlapP2PBase : public CommOverlapCore {
   cudaStream_t _stream_recv;
   cudaEvent_t _stop_send, _stop_recv;

+ private:
+  void initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
+                  CommOverlapType comm_type, bool aggregate);
+
  public:
   CommOverlapP2PBase() {}  // dummy constructor for exposing type to Python
...
@@ -241,6 +263,9 @@ class CommOverlapP2PBase : public CommOverlapCore {
   virtual ~CommOverlapP2PBase();

+  void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk,
+                        bool rowwise = true) override;
+
+  TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id);
+
   void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
...
transformer_engine/common/include/transformer_engine/fused_attn.h
...
@@ -124,6 +124,24 @@ enum NVTE_Mask_Type {
   NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5,
 };

+/*! \enum NVTE_Softmax_Type
+ *  \brief Attention softmax types as described in
+ *  Efficient Streaming Language Models with Attention Sinks (https://arxiv.org/pdf/2309.17453v3).
+ *  For a given attention score S = Q*K^T, different softmax types perform different operations on S,
+ *  NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
+ *  NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
+ *  NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
+ *  where alpha is a learnable parameter in shape [H].
+ */
+enum NVTE_Softmax_Type {
+  /*! Vanilla softmax */
+  NVTE_VANILLA_SOFTMAX = 0,
+  /*! Off-by-one softmax */
+  NVTE_OFF_BY_ONE_SOFTMAX = 1,
+  /*! Learnable softmax */
+  NVTE_LEARNABLE_SOFTMAX = 2,
+};
+
 /*! \enum NVTE_Fused_Attn_Backend
  *  \brief Fused attention backends
  */
...
@@ -178,6 +196,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
  * \param[in]     qkv_layout        The layout of Tensors Q, K, V.
  * \param[in]     bias_type         The attention bias type.
  * \param[in]     attn_mask_type    The attention mask type.
+ * \param[in]     softmax_type      The attention softmax type.
  * \param[in]     dropout           The dropout probability.
  * \param[in]     num_attn_heads    The number of heads in Q.
  * \param[in]     num_gqa_groups    The number of heads in K, V.
...
@@ -190,9 +209,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
  */
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
     bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout,
-    size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv,
-    size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+    float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
+    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
+    int64_t window_size_right);

 /*! \brief Compute dot product attention with packed QKV input.
  *
...
@@ -224,6 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
  * \param[in]     QKV              The QKV tensor in packed format, H3D or 3HD.
  * \param[in]     Bias             The Bias tensor.
+ * \param[in]     SoftmaxOffset    The SoftmaxOffset tensor.
  * \param[in,out] S                The S tensor.
  * \param[out]    O                The output O tensor.
  * \param[out]    Aux_CTX_Tensors  Auxiliary output tensors when training,
...
@@ -239,19 +260,19 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
  * \param[in]     qkv_layout         QKV tensor's layout.
  * \param[in]     bias_type          Bias type.
  * \param[in]     attn_mask_type     Attention mask type.
+ * \param[in]     softmax_type       Attention softmax type.
  * \param[in]     window_size_left   Sliding window size (the left half).
  * \param[in]     window_size_right  Sliding window size (the right half).
  * \param[in]     workspace          Workspace tensor.
  * \param[in]     stream             CUDA stream used for this operation.
  */
-void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
-                                   NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
-                                   const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
-                                   const NVTETensor rng_state, size_t max_seqlen, bool is_training,
-                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
-                                   NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-                                   int64_t window_size_left, int64_t window_size_right,
-                                   NVTETensor workspace, cudaStream_t stream);
+void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
+                                   const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
+                                   NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
+                                   const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
+                                   size_t max_seqlen, bool is_training, float attn_scale,
+                                   float dropout, NVTE_QKV_Layout qkv_layout,
+                                   NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+                                   NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+                                   int64_t window_size_right, NVTETensor workspace,
+                                   cudaStream_t stream);

 /*! \brief Compute the backward of the dot product attention with packed QKV input.
  *
...
@@ -284,6 +305,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
  *                                  e.g. M, ZInv, rng_state.
  * \param[out]    dQKV             The gradient of the QKV tensor.
  * \param[out]    dBias            The gradient of the Bias tensor.
+ * \param[out]    dSoftmaxOffset   The gradient of the SoftmaxOffset tensor.
  * \param[in]     cu_seqlens       Cumulative sequence lengths, [batch_size + 1].
  * \param[in]     cu_seqlens_padded  Cumulative sequence offsets for QKV, [batch_size + 1].
  * \param[in]     max_seqlen       Max sequence length used for computing,
...
@@ -293,6 +315,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
  * \param[in]     qkv_layout         QKV tensor's layout.
  * \param[in]     bias_type          Bias type.
  * \param[in]     attn_mask_type     Attention mask type.
+ * \param[in]     softmax_type       Attention softmax type.
  * \param[in]     window_size_left   Sliding window size (the left half).
  * \param[in]     window_size_right  Sliding window size (the right half).
  * \param[in]     deterministic      Whether to execute with deterministic behaviours.
...
@@ -302,10 +325,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
 void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
                                    const NVTETensor S, NVTETensor dP,
                                    const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
-                                   NVTETensor dBias, const NVTETensor cu_seqlens,
-                                   const NVTETensor cu_seqlens_padded, size_t max_seqlen,
-                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
-                                   NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+                                   NVTETensor dBias, NVTETensor dSoftmaxOffset,
+                                   const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
+                                   size_t max_seqlen, float attn_scale, float dropout,
+                                   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                                   NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
                                    int64_t window_size_left, int64_t window_size_right,
                                    bool deterministic, NVTETensor workspace, cudaStream_t stream);
...
@@ -340,6 +364,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
  * \param[in]     Q                The Q tensor, in HD layouts.
  * \param[in]     KV               The KV tensor, in 2HD or H2D layouts.
  * \param[in]     Bias             The Bias tensor.
+ * \param[in]     SoftmaxOffset    The SoftmaxOffset tensor.
  * \param[in,out] S                The S tensor.
  * \param[out]    O                The output O tensor.
  * \param[out]    Aux_CTX_Tensors  Auxiliary output tensors when training,
...
@@ -361,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
  * \param[in]     qkv_layout         QKV tensor's layout.
  * \param[in]     bias_type          Bias type.
  * \param[in]     attn_mask_type     Attention mask type.
+ * \param[in]     softmax_type       Attention softmax type.
  * \param[in]     window_size_left   Sliding window size (the left half).
  * \param[in]     window_size_right  Sliding window size (the right half).
  * \param[in]     deterministic      Whether to execute with deterministic behaviours.
...
@@ -368,13 +394,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
  * \param[in]     stream             CUDA stream used for this operation.
  */
 void nvte_fused_attn_fwd_kvpacked(
-    const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
-    NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
-    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
-    const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
-    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
+    const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
+    const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
+    const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
+    const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
+    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
+    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+    NVTETensor workspace, cudaStream_t stream);

 /*! \brief Compute the backward of the dot product attention with packed KV input.
  *
...
@@ -409,6 +437,7 @@ void nvte_fused_attn_fwd_kvpacked(
  * \param[out]    dQ               The gradient of the Q tensor.
  * \param[out]    dKV              The gradient of the KV tensor.
  * \param[out]    dBias            The gradient of the Bias tensor.
+ * \param[out]    dSoftmaxOffset   The gradient of the SoftmaxOffset tensor.
  * \param[in]     cu_seqlens_q     Cumulative sequence lengths for Q, [batch_size + 1].
  * \param[in]     cu_seqlens_kv    Cumulative sequence lengths for KV, [batch_size + 1].
  * \param[in]     cu_seqlens_q_padded  Cumulative sequence offsets for Q, [batch_size + 1].
...
@@ -422,6 +451,7 @@ void nvte_fused_attn_fwd_kvpacked(
  * \param[in]     qkv_layout         QKV tensor's layout.
  * \param[in]     bias_type          Bias type.
  * \param[in]     attn_mask_type     Attention mask type.
+ * \param[in]     softmax_type       Attention softmax type.
  * \param[in]     window_size_left   Sliding window size (the left half).
  * \param[in]     window_size_right  Sliding window size (the right half).
  * \param[in]     deterministic      Whether to execute with deterministic behaviours.
...
@@ -431,12 +461,12 @@ void nvte_fused_attn_fwd_kvpacked(
 void nvte_fused_attn_bwd_kvpacked(
     const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
     const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
-    NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
-    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
-    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
-    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace,
-    cudaStream_t stream);
+    NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q,
+    const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+    const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
+    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+    int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream);

 /*! \brief Compute dot product attention with separate Q, K and V.
  *
...
@@ -473,6 +503,7 @@ void nvte_fused_attn_bwd_kvpacked(
  * \param[in]     K                The K tensor.
  * \param[in]     V                The V tensor.
  * \param[in]     Bias             The Bias tensor.
+ * \param[in]     SoftmaxOffset    The SoftmaxOffset tensor.
  * \param[in,out] S                The S tensor.
  * \param[out]    O                The output O tensor.
  * \param[out]    Aux_CTX_Tensors  Auxiliary output tensors when training,
...
@@ -494,22 +525,24 @@ void nvte_fused_attn_bwd_kvpacked(
  * \param[in]     qkv_layout         QKV tensors' layout.
  * \param[in]     bias_type          Bias type.
  * \param[in]     attn_mask_type     Attention mask type.
+ * \param[in]     softmax_type       Attention softmax type.
  * \param[in]     window_size_left   Sliding window size (the left half).
  * \param[in]     window_size_right  Sliding window size (the right half).
  * \param[in]     workspace          Workspace tensor.
  * \param[in]     stream             CUDA stream used for this operation.
  */
 void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
-                         const NVTETensor Bias, NVTETensor S, NVTETensor O,
-                         NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
-                         const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+                         const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
+                         NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
+                         const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+                         const NVTETensor cu_seqlens_q_padded,
                          const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
                          const NVTETensor page_table_v, const NVTETensor rng_state,
                          size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
                          float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
                          NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-                         int64_t window_size_left, int64_t window_size_right,
-                         NVTETensor workspace, cudaStream_t stream);
+                         NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+                         int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);

 /*! \brief Compute the backward of the dot product attention with separate Q, K and V.
  *
...
@@ -549,6 +582,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
  * \param[out]    dK               The gradient of the K tensor.
  * \param[out]    dV               The gradient of the V tensor.
  * \param[out]    dBias            The gradient of the Bias tensor.
+ * \param[out]    dSoftmaxOffset   The gradient of the SoftmaxOffset tensor.
  * \param[in]     cu_seqlens_q     Cumulative sequence lengths for Q, [batch_size + 1].
  * \param[in]     cu_seqlens_kv    Cumulative sequence lengths for K and V, [batch_size + 1].
  * \param[in]     cu_seqlens_q_padded  Cumulative sequence offsets for Q, [batch_size + 1].
...
@@ -562,6 +596,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
  * \param[in]     qkv_layout         QKV tensors' layout.
  * \param[in]     bias_type          Bias type.
  * \param[in]     attn_mask_type     Attention mask type.
+ * \param[in]     softmax_type       Attention softmax type.
  * \param[in]     window_size_left   Sliding window size (the left half).
  * \param[in]     window_size_right  Sliding window size (the right half).
  * \param[in]     deterministic      Whether to execute with deterministic behaviours.
...
@@ -571,14 +606,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
                          const NVTETensor O, const NVTETensor dO, const NVTETensor S,
                          NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
                          NVTETensor dK,
-                         NVTETensor dV, NVTETensor dBias,
-                         const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
-                         const NVTETensor cu_seqlens_q_padded,
+                         NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset,
+                         const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+                         const NVTETensor cu_seqlens_q_padded,
                          const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
                          size_t max_seqlen_kv, float attn_scale, float dropout,
                          NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                         NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
-                         int64_t window_size_right, bool deterministic, NVTETensor workspace,
-                         cudaStream_t stream);
+                         NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+                         int64_t window_size_left, int64_t window_size_right, bool deterministic,
+                         NVTETensor workspace, cudaStream_t stream);

 /*! \brief Update the RNG state with the seed and calculated offset.
  *
...
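The NVTE_Softmax_Type documentation introduced above defines three different denominators for the attention probabilities. The following small sketch (not part of this commit) shows the same three variants for a single row of attention scores; the enum and helper names here are illustrative stand-ins, and a production implementation would also subtract the row maximum for numerical stability and fold in masking, scaling and dropout.

// A per-row reference of the three softmax denominators, assuming unmasked scores.
#include <cmath>
#include <cstddef>
#include <vector>

enum class SoftmaxType { kVanilla, kOffByOne, kLearnable };

std::vector<float> softmax_row(const std::vector<float> &s, SoftmaxType type, float alpha = 0.0f) {
  std::vector<float> out(s.size());
  float denom = 0.0f;
  for (float v : s) denom += std::exp(v);        // sum(exp(S)) over the last dimension
  switch (type) {
    case SoftmaxType::kVanilla:                               break;  // sum(exp(S))
    case SoftmaxType::kOffByOne:  denom += 1.0f;              break;  // 1 + sum(exp(S))
    case SoftmaxType::kLearnable: denom += std::exp(alpha);   break;  // exp(alpha[head]) + sum(exp(S))
  }
  for (std::size_t i = 0; i < s.size(); ++i) out[i] = std::exp(s[i]) / denom;
  return out;
}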
transformer_engine/common/include/transformer_engine/gemm.h
...
@@ -15,9 +15,76 @@
 #ifdef __cplusplus
 extern "C" {
-#endif
+#endif  // __cplusplus

-/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations.
+/*! \brief Configuration for matrix multiplication. */
+typedef void *NVTEMatmulConfig;
+
+/*! \enum NVTEMatmulConfigAttribute
+ *  \brief Type of option for matrix multiplication.
+ */
+enum NVTEMatmulConfigAttribute {
+  /*! Bias tensor
+   *
+   * If provided, the bias tensor is applied in the GEMM epilogue.
+   */
+  kNVTEMatmulConfigBiasTensor = 0,
+  /*! Bias gradient tensor
+   *
+   * If provided, the bias gradient tensor will be filled in the GEMM epilogue.
+   */
+  kNVTEMatmulConfigDBiasTensor = 1,
+  /*! Whether to compute GELU in GEMM epilogue. */
+  kNVTEMatmulConfigWithGELUEpilogue = 2,
+  /*! Whether to compute GELU backward in GEMM epilogue. */
+  kNVTEMatmulConfigWithDGELUEpilogue = 3,
+  /*! Auxiliary tensor for GEMM epilogue.
+   *
+   * For GELU, this will be filled with the GELU input. For GELU
+   * backward, this is expected to already be filled with the GELU
+   * input.
+   */
+  kNVTEMatmulConfigEpilogueAuxTensor = 4,
+  /*! Whether to use split accumulator for FP8 GEMM. */
+  kNVTEMatmulConfigUseSplitAccumulator = 5,
+  /*! Number of streaming multiprocessors to use in GEMM kernel. */
+  kNVTEMatmulConfigSMCount = 6,
+  kNVTEMatmulConfigNumAttributes
+};
+
+/*! \brief Create a matrix multiplication configuration. */
+NVTEMatmulConfig nvte_create_matmul_config();
+
+/*! \brief Query an option in matrix multiplication configuration.
+ *
+ * \param[in]  config         Matrix multiplication configuration.
+ * \param[in]  attr           Option type.
+ * \param[out] buf            Memory address to write option value. Ignored if NULL.
+ * \param[in]  size_in_bytes  Size of buf.
+ * \param[out] size_written   Number of bytes that have been written to buf. If buf is NULL,
+ *                            then the number of bytes that would have been written.
+ */
+void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
+                                      void *buf, size_t size_in_bytes, size_t *size_written);
+
+/*! \brief Set an option in matrix multiplication configuration.
+ *
+ * \param[in]  config         Matrix multiplication configuration.
+ * \param[in]  attr           Option type.
+ * \param[out] buf            Memory address to read option value.
+ * \param[in]  size_in_bytes  Size of buf.
+ */
+void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
+                                      const void *buf, size_t size_in_bytes);
+
+/*! \brief Destroy a matrix multiplication configuration. */
+void nvte_destroy_matmul_config(NVTEMatmulConfig config);
+
+/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated).
+ *
+ * This has been deprecated in favor of nvte_cublas_gemm_v2.
 *
 * Computes:
 *  - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
...
@@ -44,8 +111,31 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
                       NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                       int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0,
                       bool nvte_use_rocblas = 0, int compute_stream_offset = 0);

+/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations.
+ *
+ * Computes:
+ *  - `D = alpha * op(A) * op(B) + beta * C`
+ *
+ * \param[in]     transa     Whether to transpose A matrix.
+ * \param[in]     transb     Whether to transpose B matrix.
+ * \param[in]     alpha      Scaling factor applied to matmul output.
+ * \param[in]     A          A matrix.
+ * \param[in]     B          B matrix.
+ * \param[in]     beta       Scaling factor applied to C matrix.
+ * \param[in]     C          C matrix.
+ * \param[out]    D          Output matrix.
+ * \param[in]     workspace  Workspace tensor.
+ * \param[in]     config     Additional configuration.
+ * \param[in]     stream     CUDA stream used for the operation.
+ */
+void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A,
+                         const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D,
+                         NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream,
+                         bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0,
+                         int compute_stream_offset = 0);
+
 /*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations,
- *         allowing for using a scaling factor for the GEMM result and the accumulation input
+ *         allowing for using a scaling factor for the GEMM result and the accumulation input (deprecated)
+ *
+ * This has been deprecated in favor of nvte_cublas_gemm_v2.
 *
 * Computes:
 *  - `D = alpha*AB` if both `bias` and `pre_gelu_out` are empty tensors
...
@@ -133,9 +223,9 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
 * \param[in]     math_sm_count     Number of GPU SMs to use (default=0: use cuBLAS heuristics)
 * \param[in]     stream            CUDA stream to wait on.
 */
-void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
-                            const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
-                            bool transa, bool transb, bool grad, NVTETensor *workspace,
+void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
+                            const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
+                            bool transa, bool transb, bool grad, NVTETensor *workspace,
                             bool accumulate, bool use_split_accumulator, int math_sm_count,
                             cudaStream_t stream);
...
@@ -165,7 +255,9 @@ void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTE
 #ifdef __cplusplus
 }  // extern "C"
-#endif
+#endif  // __cplusplus

+#ifdef __cplusplus
 /*! \namespace transformer_engine
 */
...
@@ -183,6 +275,89 @@ constexpr int num_batchgemm_streams = 1;
 void nvte_cublas_handle_init();

+/*! \struct MatmulConfigWrapper
+ *  \brief C++ wrapper for NVTEMatmulConfig.
+ */
+class MatmulConfigWrapper {
+ public:
+  MatmulConfigWrapper() : config_{nvte_create_matmul_config()} {}
+
+  MatmulConfigWrapper(const MatmulConfigWrapper &) = delete;
+  MatmulConfigWrapper &operator=(const MatmulConfigWrapper &) = delete;
+
+  MatmulConfigWrapper(MatmulConfigWrapper &&other) : config_{other.config_} {
+    other.config_ = nullptr;
+  }
+  MatmulConfigWrapper &operator=(MatmulConfigWrapper &&other) {
+    if (config_ != nullptr) {
+      nvte_destroy_matmul_config(config_);
+    }
+    config_ = other.config_;
+    other.config_ = nullptr;
+    return *this;
+  }
+
+  ~MatmulConfigWrapper() {
+    if (config_ != nullptr) {
+      nvte_destroy_matmul_config(config_);
+      config_ = nullptr;
+    }
+  }
+
+  /*! \brief Get the underlying NVTEMatmulConfig.
+   *
+   *  \return NVTEMatmulConfig held by this MatmulConfigWrapper.
+   */
+  operator NVTEMatmulConfig() const noexcept { return config_; }
+
+  /*! \brief Set bias tensor. */
+  void set_bias_tensor(NVTETensor bias_tensor) {
+    nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigBiasTensor, &bias_tensor,
+                                     sizeof(NVTETensor));
+  }
+
+  /*! \brief Set bias gradient tensor. */
+  void set_dbias_tensor(NVTETensor dbias_tensor) {
+    nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigDBiasTensor, &dbias_tensor,
+                                     sizeof(NVTETensor));
+  }
+
+  /*! \brief Set whether to compute GELU in GEMM epilogue. */
+  void set_with_gelu_epilogue(bool with_gelu_epilogue) {
+    nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithGELUEpilogue,
+                                     &with_gelu_epilogue, sizeof(bool));
+  }
+
+  /*! \brief Set whether to compute GELU backward in GEMM epilogue. */
+  void set_with_dgelu_epilogue(bool with_dgelu_epilogue) {
+    nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithDGELUEpilogue,
+                                     &with_dgelu_epilogue, sizeof(bool));
+  }
+
+  /*! \brief Set auxiliary tensor for GEMM epilogue. */
+  void set_epilogue_aux_tensor(NVTETensor epilogue_aux_tensor) {
+    nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigEpilogueAuxTensor,
+                                     &epilogue_aux_tensor, sizeof(NVTETensor));
+  }
+
+  /*! \brief Set whether to use split accumulator for FP8 GEMM. */
+  void set_use_split_accumulator(bool use_split_accumulator) {
+    nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigUseSplitAccumulator,
+                                     &use_split_accumulator, sizeof(bool));
+  }
+
+  /*! \brief Set number of streaming multiprocessors to use in GEMM kernel. */
+  void set_sm_count(int sm_count) {
+    nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigSMCount, &sm_count, sizeof(int));
+  }
+
+ private:
+  /*! \brief Wrapped NVTEMatmulConfig. */
+  NVTEMatmulConfig config_ = nullptr;
+};
+
 }  // namespace transformer_engine
+#endif  // __cplusplus

 #endif  // TRANSFORMER_ENGINE_GEMM_H_
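As a rough illustration of how the new configuration object and nvte_cublas_gemm_v2 are meant to be combined, here is a hypothetical call site. It is not part of this commit: the TensorWrapper arguments are assumed to have been created and quantized elsewhere, and whether C may alias D (or be a null tensor) when beta is zero is an assumption of this sketch rather than a documented guarantee.

// Minimal sketch of driving nvte_cublas_gemm_v2 through MatmulConfigWrapper.
#include <cuda_runtime.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>

void example_gemm(const transformer_engine::TensorWrapper &a,
                  const transformer_engine::TensorWrapper &b,
                  transformer_engine::TensorWrapper &d,
                  const transformer_engine::TensorWrapper &bias,
                  transformer_engine::TensorWrapper &workspace, cudaStream_t stream) {
  using namespace transformer_engine;
  MatmulConfigWrapper config;
  config.set_bias_tensor(bias.data());      // apply bias in the GEMM epilogue
  config.set_use_split_accumulator(true);   // FP8 split accumulator
  const float alpha = 1.0f, beta = 0.0f;
  // With beta == 0 the C operand does not contribute; the output tensor is passed
  // as a placeholder here (an assumption of this sketch).
  nvte_cublas_gemm_v2(/*transa=*/1, /*transb=*/0, &alpha, a.data(), b.data(), &beta,
                      /*C=*/d.data(), d.data(), workspace.data(), config, stream);
}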
transformer_engine/common/include/transformer_engine/hadamard_transform.h
0 → 100644
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file hadamard_transform.h
 *  \brief Functions for Hadamard transforms.
 */

#ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
#define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_

#include "transformer_engine.h"

#ifdef __cplusplus
extern "C" {
#endif

/*! \brief Perform a randomized Hadamard transform on the input tensor.
 *
 * This function is experimental and the API is not stable.
 *
 * \param[in]     input               Input tensor to apply Hadamard transform.
 * \param[in,out] output              Output tensor.
 * \param[in]     random_sign_mask    16-bit sign mask.
 * \param[in]     random_sign_mask_t  16-bit sign mask.
 * \param[in]     stream              CUDA stream used for the operation.
 */
void nvte_hadamard_transform(const NVTETensor input, NVTETensor output, int random_sign_mask,
                             int random_sign_mask_t, cudaStream_t stream);

/*! \brief Perform the absolute maximum reduction on the input tensor with/without
 *         randomized hadamard transform. The rowwise result is the absolute maximum
 *         of the input tensor. The columnwise result is the absolute maximum of the
 *         input tensor transposed and applied randomized hadamard transformation.
 *
 * This function is experimental and the API is not stable.
 *
 * \param[in]     input               Input tensor to apply Hadamard transform.
 * \param[in,out] output              Output tensor.
 * \param[in]     random_sign_mask    16-bit sign mask.
 * \param[in]     random_sign_mask_t  16-bit sign mask.
 * \param[in]     stream              CUDA stream used for the operation.
 */
void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int random_sign_mask,
                                  int random_sign_mask_t, cudaStream_t stream);

/*! \brief Perform the columnwise hadamard transform cast fusion.
 *
 * This function is experimental and the API is not stable.
 *
 * \param[in]     input            Input tensor to apply Hadamard transform.
 * \param[in,out] output           Output tensor.
 * \param[in]     hadamard_matrix  Hadamard matrix.
 * \param[in]     quant_config     Quantization configuration.
 * \param[in]     stream           CUDA stream used for the operation.
 */
void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output,
                                                    const NVTETensor hadamard_matrix,
                                                    const NVTEQuantizationConfig quant_config,
                                                    cudaStream_t stream);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
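A hypothetical call site for the columnwise fusion declared above (not part of this commit). It assumes the output NVTETensor has already been allocated as an NVFP4 tensor with data, scale-inverse and amax buffers, that QuantizationConfigWrapper converts implicitly to NVTEQuantizationConfig, and that rng_state is a two-element int64 tensor that is only required when stochastic rounding is enabled.

// Minimal sketch of invoking the RHT + NVFP4 cast fusion with stochastic rounding.
#include <cuda_runtime.h>
#include <transformer_engine/hadamard_transform.h>
#include <transformer_engine/transformer_engine.h>

void example_rht_cast(NVTETensor input, NVTETensor output, NVTETensor hadamard_matrix,
                      NVTETensor rng_state, cudaStream_t stream) {
  transformer_engine::QuantizationConfigWrapper quant_config;
  quant_config.set_stochastic_rounding(true);  // pair with a valid rng_state tensor of shape [2]
  quant_config.set_rng_state(rng_state);
  nvte_hadamard_transform_cast_fusion_columnwise(input, output, hadamard_matrix, quant_config,
                                                 stream);
}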
transformer_engine/common/include/transformer_engine/recipe.h
...
@@ -124,6 +124,10 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
                                          size_t start_offset, size_t block_len,
                                          const NVTEDType out_dtype, cudaStream_t stream);

+void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A,
+                                         const NVTETensor inpB, const bool use_rowwise_amax_B,
+                                         float alpha_in, NVTETensor alpha_out,
+                                         cudaStream_t stream);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
...
transformer_engine/common/include/transformer_engine/transformer_engine.h
...
@@ -73,6 +73,7 @@ enum NVTETensorParam {
   kNVTEAmax = 3,                /*!< Amax tensor */
   kNVTERowwiseScaleInv = 4,     /*!< Scale inverse tensor for decoding Rowwise Data */
   kNVTEColumnwiseScaleInv = 5,  /*!< Scale inverse tensor for decoding Columnwise Data */
+  kNVTEColumnwiseAmax = 6,      /*!< Columnwise Amax tensor */
   kNVTENumTensorParams
 };
...
@@ -95,10 +96,9 @@ enum NVTEScalingMode {
   */
   NVTE_BLOCK_SCALING_1D = 2,
   NVTE_BLOCK_SCALING_2D = 3,
-  /*! Single NVFP4 scale per block of 16 contiguous elements in forward pass (FWD),
-      and single MXFP8 scale per block of 32 contiguous elements in backward pass (BWD).
-  */
-  NVTE_FWD_NVFP4_BWD_MXFP8_SCALING = 4,
+  /*! Single scale per block of 16 elements consecutive in either
+   *  rowwise or columnwise direction */
+  NVTE_NVFP4_1D_SCALING = 4,
   NVTE_INVALID_SCALING = 100
 };
...
@@ -337,6 +337,12 @@ enum NVTEQuantizationConfigAttribute {
   *  likely be refactored away in the future.
   */
   kNVTEQuantizationConfigFloat8BlockScaleTensorFormat = 3,
+  /*! RNG state (NVTETensor with 2 elements - seed and offset */
+  kNVTEQuantizationConfigRNGState = 4,
+  /*! Whether to use 2D block scaling for NVFP4 */
+  kNVTEQuantizationConfigNVFP42DQuantization = 5,
+  /*! Whether to enable stochastic rounding */
+  kNVTEQuantizationConfigStochasticRounding = 6,
   kNVTEQuantizationConfigNumAttributes
 };
...
@@ -458,6 +464,15 @@ inline bool is_fp4_dtype(const DType t) {
 #endif
 }

+/*! \brief Check if TE datatype is high precision (FP32, FP16, BF16)
+ *
+ * Return true if TE datatype is high precision
+ * \param[in] DType  TE Datatype of interest
+ */
+inline bool is_high_precision_dtype(const DType t) {
+  return t == DType::kFloat32 || t == DType::kBFloat16 || t == DType::kFloat16;
+}
+
 /*! \struct TensorWrapper
  *  \brief C++ wrapper for the NVTETensor class.
  */
...
@@ -593,6 +608,11 @@ class TensorWrapper {
     return set_parameter(kNVTEColumnwiseScaleInv, dptr, type, shape);
   }

+  template <typename ShapeType>
+  TensorWrapper &set_columnwise_amax(void *dptr, DType type, const ShapeType &shape) noexcept {
+    return set_parameter(kNVTEColumnwiseAmax, dptr, type, shape);
+  }
+
   // Parameter getters
   NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept {
...
@@ -617,6 +637,10 @@ class TensorWrapper {
     return get_parameter(kNVTEColumnwiseScaleInv);
   }

+  NVTEBasicTensor get_columnwise_amax() const noexcept {
+    return get_parameter(kNVTEColumnwiseAmax);
+  }
+
   /*! \brief Get an underlying NVTETensor.
    *
    *  \return NVTETensor held by this TensorWrapper.
...
@@ -865,6 +889,24 @@ class QuantizationConfigWrapper {
                                            &format, sizeof(Float8BlockScaleTensorFormat));
   }

+  /*! \brief Set stochastic rounding state */
+  void set_rng_state(NVTETensor rng_state) {
+    nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigRNGState, &rng_state,
+                                           sizeof(NVTETensor));
+  }
+
+  /*! \brief Set whether to use 2D block scaling for NVFP4 */
+  void set_nvfp4_2d_quantization(bool nvfp4_2d_quantization) {
+    nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP42DQuantization,
+                                           &nvfp4_2d_quantization, sizeof(bool));
+  }
+
+  /*! \brief Set whether to use stochastic rounding */
+  void set_stochastic_rounding(bool stochastic_rounding) {
+    nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigStochasticRounding,
+                                           &stochastic_rounding, sizeof(bool));
+  }
+
  private:
   /*! \brief Wrapped NVTEQuantizationConfig. */
   NVTEQuantizationConfig config_ = nullptr;
...
transformer_engine/common/normalization/layernorm/ln_api.cpp
...
@@ -28,7 +28,7 @@ void layernorm_fwd(const Tensor& x,        // BxSxhidden_size
                    const int multiprocessorCount, const bool zero_centered_gamma,
                    cudaStream_t stream) {
   if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
-      !is_mxfp_scaling(z->scaling_mode)) {
+      !is_mxfp8_scaling(z->scaling_mode)) {
     NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
   }
...
@@ -65,11 +65,11 @@ void layernorm_fwd(const Tensor& x,        // BxSxhidden_size
   bool is_aligned = true;
 #ifdef USE_ROCM
-  NVTE_CHECK(!is_mxfp_scaling(z->scaling_mode),
+  NVTE_CHECK(!is_mxfp8_scaling(z->scaling_mode),
              "Cudnn backend is need by block scaling mode for normalization! Not surpported in rocm yet.");
-  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
+  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode);
 #else
-  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
+  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode);
 #endif
   if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
...
transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
...
@@ -24,7 +24,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
                  Tensor *rsigma, Tensor *workspace, const int multiprocessorCount,
                  const bool zero_centered_gamma, cudaStream_t stream) {
   if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
-      !is_mxfp_scaling(z->scaling_mode)) {
+      !is_mxfp8_scaling(z->scaling_mode)) {
     NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
   }
...
@@ -51,11 +51,11 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
   bool is_aligned = true;
 #ifdef USE_ROCM
-  NVTE_CHECK(!is_mxfp_scaling(z->scaling_mode),
+  NVTE_CHECK(!is_mxfp8_scaling(z->scaling_mode),
              "Cudnn backend is need by mxfp scaling mode for normalization! Not surpported in rocm yet.");
-  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
+  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode);
 #else
-  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
+  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode);
 #endif
   if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
...
transformer_engine/common/recipe/__init__.py
...
...
@@ -4,7 +4,6 @@
"""This module provides predefined FP8 recipes."""
from
__future__
import
annotations
import
warnings
import
os
from
enum
import
Enum
from
typing
import
Literal
,
Optional
,
Union
,
Callable
,
NamedTuple
...
...
@@ -23,9 +22,12 @@ class _FormatHelper(NamedTuple):
class
Format
(
Enum
):
"""
Supported FP8 formats.
Supported FP4 formats.
Values
------
E2M1 :
All FP4 tensors are in e2m1 format
E4M3 :
All FP8 tensors are in e4m3 format
E5M2 :
...
...
@@ -35,6 +37,7 @@ class Format(Enum):
FP8 tensors in the backward pass are in e5m2 format
"""
E2M1
=
_FormatHelper
(
max_fwd
=
6
,
max_bwd
=
6
)
E4M3
=
_FormatHelper
(
max_fwd
=
448
,
max_bwd
=
448
)
E5M2
=
_FormatHelper
(
max_fwd
=
57344
,
max_bwd
=
57344
)
HYBRID
=
_FormatHelper
(
max_fwd
=
E4M3
.
max_fwd
,
max_bwd
=
E5M2
.
max_bwd
)
...
...
@@ -42,9 +45,13 @@ class Format(Enum):
@
dataclass
(
frozen
=
True
)
class
MMParams
:
"""for pytorch as an example, _scaled_mm use_fast_accum = (not use_split_accumulator)
apply split accumulator or not, turning it on will increase accuracy but impact gemm performance,
so only turn it on for certain gemms
"""Matrix multiplication options.
Parameters
----------
use_split_accumulator : bool, default = `True`
Use FP8 fast accumulation on Hopper or Ada. For more details,
see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul.
"""
use_split_accumulator
:
bool
=
True
...
...
@@ -55,10 +62,24 @@ class QParams:
"""Quantization parameters.
power_2_scale: use power of 2 scale parameter
amax_epsilon: optional minimum value of abs max
random_hadamard_transform: whether to use random hadamard transform
stochastic_rounding: whether to use stocastic rounding
"""
power_2_scale
:
bool
=
False
amax_epsilon
:
float
=
0.0
random_hadamard_transform
:
bool
=
False
stochastic_rounding
:
bool
=
False
fp4_2d_quantization
:
bool
=
False
def
__repr__
(
self
)
->
str
:
return
(
f
"Qparams(
\n
power_2_scale=
{
self
.
power_2_scale
}
,
\n
"
f
"amax_epsilon=
{
self
.
amax_epsilon
}
,
\n
"
f
"random_hadamard_transform=
{
self
.
random_hadamard_transform
}
,
\n
"
f
"stochastic_rounding=
{
self
.
stochastic_rounding
}
,
\n
"
f
"fp4_2d_quantization=
{
self
.
fp4_2d_quantization
}
\n
)"
)
class
Recipe
:
...
...
@@ -66,6 +87,10 @@ class Recipe:
Base recipe class.
"""
def
nvfp4
(
self
):
"""Whether the given recipe is NVFP4 1D block scaling."""
return
isinstance
(
self
,
NVFP4BlockScaling
)
def
mxfp8
(
self
):
"""Whether the given recipe is MXFP8 block scaling."""
return
isinstance
(
self
,
MXFP8BlockScaling
)
...
...
@@ -184,6 +209,7 @@ class DelayedScaling(Recipe):
            f"margin={self.margin}, "
            f"format={str(self.fp8_format).split('.')[1]}, "
            f"amax_history_len={self.amax_history_len}, "
            f"reduce_amax={self.reduce_amax}, "
            f"fp8_dpa={self.fp8_dpa}, "
            f"fp8_mha={self.fp8_mha}"
        )
...
...
@@ -201,10 +227,11 @@ class Float8CurrentScaling(Recipe):
    pass.
    """

+   use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1"
    fp8_format: Format = Format.HYBRID
-   fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0)
-   fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0)
-   fp8_quant_bwd_grad = QParams(power_2_scale=False, amax_epsilon=0.0)
+   fp8_quant_fwd_inp = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0)
+   fp8_quant_fwd_weight = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0)
+   fp8_quant_bwd_grad = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0)
    fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=False)
    fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True)
    fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True)
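Because the class attribute above is evaluated when the recipe module is imported, the power-of-2 toggle has to be set in the environment beforehand. A minimal Python sketch of that flow, assuming the usual `transformer_engine.common.recipe` import path:

# Hypothetical usage sketch: NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES is read at module
# import time, so it must be exported before importing Transformer Engine.
import os

os.environ["NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES"] = "1"

from transformer_engine.common.recipe import Float8CurrentScaling  # assumed import path

recipe = Float8CurrentScaling()
# All three QParams now request power-of-2 scales.
assert recipe.fp8_quant_fwd_inp.power_2_scale
assert recipe.fp8_quant_bwd_grad.power_2_scale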
...
...
@@ -213,9 +240,6 @@ class Float8CurrentScaling(Recipe):
    def __post_init__(self) -> None:
        assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
-       assert (
-           not self.fp8_dpa and not self.fp8_mha
-       ), "FP8 attention is not supported for Float8CurrentScaling."

    def __repr__(self) -> str:
        return (
...
...
@@ -351,3 +375,84 @@ class Float8BlockScaling(Recipe):
            f"fp8_dpa={self.fp8_dpa}, "
            f"fp8_mha={self.fp8_mha}"
        )
@dataclass()
class NVFP4BlockScaling(Recipe):
    """
    Use the NVFP4 scaling strategy.

    This is a 2-level block scaling strategy. In level 1, each group of
    16 consecutive values is scaled together using their own scaling
    factor. The type of the scaling factor is E4M3 (4 bits of exponent,
    3 bits of mantissa). In level 2, a global per tensor FP32 scaling
    factor is used to scale the entire tensor.

    Since the scaling happens in a particular direction (either rowwise
    or columnwise), in this recipe the quantized tensor and its transpose
    are not numerically equivalent. Due to this, when Transformer Engine
    needs both the tensor and its transpose (e.g. to calculate both
    forward and backward pass), during the quantization both versions are
    computed from the high precision input to avoid double quantization
    errors.

    Parameters
    ----------
    fp4_format : {Format.E2M1}, default = Format.E2M1
        FP4 data type.
    fp8_format : {Format.E4M3}, default = Format.E4M3
        FP8 data type. Only E4M3 is supported.
    fp8_dpa: bool, default = `False`
        FP8 dot product attention. Not yet supported.
    fp8_mha: bool, default = `False`
        FP8 multi-head attention. Not yet supported.
    """

    # Configuration envvars
    disable_rht: bool = os.getenv("NVTE_NVFP4_DISABLE_RHT", "0") == "1"
    disable_stochastic_rounding: bool = (
        os.getenv("NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING", "0") == "1"
    )
    disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1"

    fp4_format: Format = Format.E2M1
    fp8_format: Format = Format.E4M3

    # Not applying quantization to attention for now
    fp8_dpa: bool = False
    fp8_mha: bool = False

    def __post_init__(self) -> None:
        assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling"
        assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling"

        # Quantization params
        # Note: RHT is currently only applied to column-wise usage so that
        # it can be used for wgrad GEMM.
        self.fp4_quant_fwd_inp = QParams(
            random_hadamard_transform=not self.disable_rht,
            stochastic_rounding=False,
            fp4_2d_quantization=False,
        )
        self.fp4_quant_fwd_weight = QParams(
            random_hadamard_transform=False,
            stochastic_rounding=False,
            fp4_2d_quantization=not self.disable_2d_quantization,
        )
        self.fp4_quant_bwd_grad = QParams(
            random_hadamard_transform=not self.disable_rht,
            stochastic_rounding=not self.disable_stochastic_rounding,
            fp4_2d_quantization=False,
        )

    def __repr__(self) -> str:
        return (
            f"recipe_type={self.__class__.__name__}, "
            f"fp4_format={str(self.fp4_format).split('.')[1]}, "
            f"fp8_format={str(self.fp8_format).split('.')[1]}, "
            f"fp8_dpa={self.fp8_dpa}, "
            f"fp8_mha={self.fp8_mha}, "
            f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, "
            f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, "
            f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, "
        )
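A minimal Python usage sketch for the new recipe (the import path is assumed to be the same one used for the other recipes in this module):

from transformer_engine.common.recipe import NVFP4BlockScaling  # assumed import path

recipe = NVFP4BlockScaling()
print(recipe)  # shows the fp4_quant_* QParams derived from the NVTE_NVFP4_* env toggles

# The three QParams encode the defaults described in the docstring above:
#   - activations: RHT on (unless disabled), no stochastic rounding, 1D (1x16) blocks
#   - weights:     no RHT, 2D (16x16) quantization unless disabled
#   - grads:       RHT and stochastic rounding on unless disabled
assert recipe.fp4_quant_fwd_weight.fp4_2d_quantization or recipe.disable_2d_quantization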
transformer_engine/common/recipe/current_scaling.cu
View file @ 53fa872c
...
...
@@ -27,6 +27,13 @@ namespace {
constexpr int amax_kernel_threads = 512;

__launch_bounds__(1) __global__ void zero_amax_kernel(float *amax_ptr, const float *noop_ptr) {
  if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
    return;
  }
  *amax_ptr = 0;
}

template <int nvec, bool aligned, typename InputType>
__launch_bounds__(amax_kernel_threads) __global__
    void amax_kernel(const InputType *input, float *amax, const size_t N,
...
...
@@ -131,7 +138,8 @@ template <int nvec, typename InputType>
void launch_amax_kernel(const InputType *input, float *amax, const size_t N,
                        const float *noop_ptr, cudaStream_t stream) {
  // Zero out amax so we can update with atomic max
-  NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream));
+  zero_amax_kernel<<<1, 1, 0, stream>>>(amax, noop_ptr);
+  NVTE_CHECK_CUDA(cudaGetLastError());

  // Return immediately if tensor is empty
  if (N == 0) {
...
...
@@ -216,15 +224,17 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt
  // Check output tensor
  NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
  auto &output = *convertNVTETensorCheck(output_);
-  NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
-             "Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
+  NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING ||
+                 output.scaling_mode == NVTE_NVFP4_1D_SCALING,
+             "Output tensor for amax computation must be FP8 tensor with per-tensor scaling or "
+             "NVFP4 1D scaling, "
             "but got scaling_mode=", to_string(output.scaling_mode));
  NVTE_CHECK(output.amax.numel() == 1,
             "Output tensor for amax computation has invalid amax tensor "
             "(expected 1 entry, got shape=", output.amax.shape, ")");
-  NVTE_CHECK(output.amax.dptr != nullptr,
+  NVTE_CHECK(output.amax.dptr != nullptr || output.columnwise_amax.dptr != nullptr,
             "Output tensor for amax computation has amax tensor without data");
  NVTE_CHECK(output.amax.dtype == DType::kFloat32,
             "Output tensor for amax computation has invalid amax tensor "
...
...
@@ -243,11 +253,12 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt
  }

  // Compute amax
+  float *amax_ptr = reinterpret_cast<float *>(
+      (output.amax.dptr != nullptr) ? output.amax.dptr : output.columnwise_amax.dptr);
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
-      input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
-      launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
-                               reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
-                               noop_ptr, stream););  // NOLINT(*)
+      input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
+      launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr), amax_ptr,
+                               input.data.numel(), noop_ptr, stream););  // NOLINT(*)
}

}  // anonymous namespace
...
...
transformer_engine/common/recipe/nvfp4.cu
0 → 100644
View file @ 53fa872c
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/recipe.h>
#include <cassert>
#include "../common.h"
#include "../utils.cuh"
namespace transformer_engine {
namespace nvfp4_recipe {

// constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0;
constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0);

// Kernel to compute alpha *= amax_A * amax_B / factor
__global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A,
                                                      const float *amax_B, float *alpha_out) {
  // factor is defined in the enclosing namespace
  *alpha_out = alpha_in * (*amax_A) * (*amax_B) * factor_inv;
}

}  // namespace nvfp4_recipe
}  // namespace transformer_engine

void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A,
                                         const NVTETensor inpB, const bool use_rowwise_amax_B,
                                         float alpha_in, NVTETensor alpha_out,
                                         cudaStream_t stream) {
  NVTE_API_CALL(nvte_nvfp4_compute_per_tensor_scale);
  using namespace transformer_engine;
  auto *tA = convertNVTETensor(inpA);
  auto *tB = convertNVTETensor(inpB);
  auto *tOut = convertNVTETensor(alpha_out);

  void *amax_A_ptr = use_rowwise_amax_A ? tA->amax.dptr : tA->columnwise_amax.dptr;
  void *amax_B_ptr = use_rowwise_amax_B ? tB->amax.dptr : tB->columnwise_amax.dptr;
  void *alpha_ptr = tOut->data.dptr;

  // check for not null pointers
  NVTE_CHECK(amax_A_ptr != nullptr, "amax_A_ptr is null");
  NVTE_CHECK(amax_B_ptr != nullptr, "amax_B_ptr is null");
  NVTE_CHECK(alpha_ptr != nullptr, "alpha_ptr is null");

  nvfp4_recipe::compute_nvfp4_per_tensor_scale_kernel<<<1, 1, 0, stream>>>(
      alpha_in, reinterpret_cast<const float *>(amax_A_ptr),
      reinterpret_cast<const float *>(amax_B_ptr), reinterpret_cast<float *>(alpha_ptr));
  NVTE_CHECK_CUDA(cudaGetLastError());
}
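The kernel above folds the two per-tensor amaxes and the constant 6.0 * 6.0 * 448.0 * 448.0 (E2M1 max squared times E4M3 max squared) into the GEMM alpha. A Python sketch of the same arithmetic, for reference only:

# Reference sketch of the per-tensor NVFP4 GEMM scale computed by the kernel above.
FP4_MAX = 6.0    # E2M1 maximum
FP8_MAX = 448.0  # E4M3 maximum
FACTOR_INV = 1.0 / (FP4_MAX * FP4_MAX * FP8_MAX * FP8_MAX)

def nvfp4_per_tensor_scale(alpha_in: float, amax_a: float, amax_b: float) -> float:
    # Undoes the (per-block decode scale) x (global encode scale) factors baked into
    # both quantized operands before the GEMM accumulates in higher precision.
    return alpha_in * amax_a * amax_b * FACTOR_INV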
transformer_engine/common/swizzle/swizzle.cu
View file @ 53fa872c
...
...
@@ -18,7 +18,9 @@
namespace transformer_engine {
namespace {

-constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32;
+constexpr int MXFP8_BLOCK_SIZE = 32;
+constexpr int NVFP4_BLOCK_SIZE = 16;
constexpr __device__ __host__ int TB_DIM = 32;
constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16;
constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4;
...
...
@@ -314,8 +316,6 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
  const int original_K = kernel_args.original_k_list[tensor_id];
  constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
  constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
  constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;

  // Get block index in grid. Emulate 2D grid.
  const int num_tiles_k = K / SF_TILE_DIM_K;
...
...
@@ -332,9 +332,13 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
}  // namespace

void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
-  if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) {
-    NVTE_ERROR("Not implemented caling mode " + to_string(input->scaling_mode) + ".");
-  }
+  NVTE_CHECK(input->scaling_mode == NVTE_MXFP8_1D_SCALING ||
+                 input->scaling_mode == NVTE_BLOCK_SCALING_1D ||
+                 input->scaling_mode == NVTE_BLOCK_SCALING_2D ||
+                 input->scaling_mode == NVTE_NVFP4_1D_SCALING,
+             "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");
+  NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()),
+             "Input tensor has invalid dtype (", to_string(input->dtype()), ").");

  // Do nothing if tensor is empty
  if (input->data.numel() == 0) {
...
...
@@ -345,13 +349,25 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
  CheckInputTensor(*output, "scaling_factor_output");

  auto& scaling_mode = input->scaling_mode;

+  NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING,
+             "Unsupported scaling mode for swizzling.");
+  bool nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING;
+
  // 1D block scaling, row-wise or colum-wise
  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
-    const int m = input->has_data() ? input->scale_inv.shape[0]
-                                    : input->columnwise_scale_inv.shape[1];
-    const int k = input->has_data() ? input->scale_inv.shape[1]
-                                    : input->columnwise_scale_inv.shape[0];
+    int m, k;
+    if (input->has_data()) {
+      m = input->scale_inv.shape[0];
+      k = input->scale_inv.shape[1];
+    } else {
+      if (nvfp4) {
+        m = input->columnwise_scale_inv.shape[0];
+        k = input->columnwise_scale_inv.shape[1];
+      } else {
+        m = input->columnwise_scale_inv.shape[1];
+        k = input->columnwise_scale_inv.shape[0];
+      }
+    }

    constexpr int SF_TILE_DIM_M = 128;
    constexpr int SF_TILE_DIM_K = 4;
...
...
@@ -375,16 +391,35 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
    int num_tiles_m = m / SF_TILE_DIM_M;
    int num_tiles_k = k / SF_TILE_DIM_K;

+    // For NVFP4, the scale inverse for transposed data needs rowwise swizzle.
+    const bool rowwise_swizzle = input->has_data() || nvfp4;
+    const bool columnwise_swizzle = input->has_columnwise_data() && !nvfp4;
+
    dim3 block_size(TB_DIM, TB_DIM);
-    if (input->has_data()) {
+    if (rowwise_swizzle) {
      int vec_load_size = (num_tiles_k - 1) % 4 + 1;
      /* there is no int3 and misaligned if using int4/int2 */
      if (vec_load_size == 3) vec_load_size = 1;
      int n_tiles_in_tb = TB_DIM * vec_load_size;
      dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
      int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
-      const int original_M = input->flat_first_dim();
-      const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE;
+      int original_M, original_K;
+      void *input_scale_inv_ptr, *output_scale_inv_ptr;
+      if (!nvfp4 || input->has_data()) {
+        int block_scale_size = nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE;
+        original_M = input->flat_first_dim();
+        original_K = input->flat_last_dim() / block_scale_size;
+        input_scale_inv_ptr = input->scale_inv.dptr;
+        output_scale_inv_ptr = output->scale_inv.dptr;
+      } else {
+        original_M = input->flat_last_dim();
+        original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE;
+        input_scale_inv_ptr = input->columnwise_scale_inv.dptr;
+        output_scale_inv_ptr = output->columnwise_scale_inv.dptr;
+      }

      switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
        case 4:
...
...
@@ -392,21 +427,21 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
                               cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
          swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(
-                 input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
+                 input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
          break;
        case 2:
          cudaFuncSetAttribute(
              (const void *)swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
          swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(
-                 input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
+                 input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
          break;
        case 1:
          cudaFuncSetAttribute(
              (const void *)swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
          swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(
-                 input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
+                 input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
          break;
#else
        case 4:
...
...
@@ -415,7 +450,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
          swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(
-                 input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
+                 input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
          break;
        case 2:
          NVTE_CHECK_CUDA(
...
...
@@ -423,7 +458,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
          swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(
-                 input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
+                 input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
          break;
        case 1:
          NVTE_CHECK_CUDA(
...
...
@@ -431,16 +466,15 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
          swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(
-                 input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
+                 input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
          break;
#endif
        default:
          NVTE_ERROR("Not valid vec_load_size.");
          break;
      }
      NVTE_CHECK_CUDA(cudaGetLastError());
    }

-    if (input->has_columnwise_data()) {
+    if (columnwise_swizzle) {
      int vec_load_size = (num_tiles_m - 1) % 4 + 1;
      if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */
      int n_tiles_in_tb = TB_DIM * vec_load_size;
...
...
@@ -448,6 +482,9 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
      int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
      const int original_M = input->flat_last_dim();
      const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE;

+      // NVFP4 shouldn't end up here because it only needs rowwise swizzle
+      NVTE_CHECK(!nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle");
+
      switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
        case 4:
...
...
@@ -481,8 +518,8 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
          swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(
                  input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k,
                  original_M, original_K);
          break;
        case 2:
          NVTE_CHECK_CUDA(
...
...
@@ -490,8 +527,8 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
          swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(
                  input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k,
                  original_M, original_K);
          break;
        case 1:
          NVTE_CHECK_CUDA(
...
...
@@ -499,20 +536,14 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
          swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(
                  input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k,
                  original_M, original_K);
          break;
#endif
        default:
          NVTE_ERROR("Not valid vec_load_size.");
          break;
      }
      NVTE_CHECK_CUDA(cudaGetLastError());
    }
    // 2D block scaling
  } else {
    NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans.");
  }
  NVTE_CHECK_CUDA(cudaGetLastError());
...
...
@@ -650,6 +681,8 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
  }
  NVTE_CHECK_CUDA(cudaGetLastError());
}

+// TODO(nvfp4): Add NVFP4 support.
void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
                                          std::vector<Tensor*>& output, cudaStream_t stream) {
  auto num_tensors = input.size();
...
@@ -776,7 +809,7 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
* WIP (Phuong):
* - Opt for bank conflicts
* - Adding swizzle for 2d-block scaling.
*/
*/
void
nvte_swizzle_scaling_factors
(
const
NVTETensor
input
,
NVTETensor
output
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_swizzle_scaling_factors
);
using
namespace
transformer_engine
;
...
...
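The dispatch logic introduced above deserves a small reference, since NVFP4 deliberately reuses the rowwise swizzle path for its transposed scale factors. A Python sketch of that decision (illustrative names, not part of the library API):

# MXFP8 swizzles rowwise scales when data is present and columnwise scales otherwise;
# NVFP4 always takes the rowwise path because its columnwise (transposed) scale factors
# are already laid out row-major with respect to the transposed data.
def pick_swizzle_paths(has_data: bool, has_columnwise_data: bool, nvfp4: bool):
    rowwise_swizzle = has_data or nvfp4
    columnwise_swizzle = has_columnwise_data and not nvfp4
    return rowwise_swizzle, columnwise_swizzle

assert pick_swizzle_paths(False, True, nvfp4=True) == (True, False)
assert pick_swizzle_paths(False, True, nvfp4=False) == (False, True)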
transformer_engine/common/transformer_engine.cpp
View file @ 53fa872c
...
...
@@ -11,6 +11,7 @@
#include <cstring>
#include <iostream>
#include <mutex>
#include <utility>
#include "common.h"
#include "common/util/cuda_runtime.h"
...
...
@@ -67,8 +68,8 @@ std::string to_string(const NVTEScalingMode &mode) {
      return "NVTE_DELAYED_TENSOR_SCALING";
    case NVTE_MXFP8_1D_SCALING:
      return "NVTE_MXFP8_1D_SCALING";
-    case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING:
-      return "NVTE_FWD_NVFP4_BWD_MXFP8_SCALING";
+    case NVTE_NVFP4_1D_SCALING:
+      return "NVTE_NVFP4_1D_SCALING";
    case NVTE_INVALID_SCALING:
      return "NVTE_INVALID_SCALING";
  }
...
...
@@ -98,12 +99,11 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
               t.columnwise_scale_inv.shape, ")");
    }
  } else {
-    if (t.scaling_mode == NVTE_MXFP8_1D_SCALING ||
-        t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) {
+    if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
      // Need (4, 128) alignment even for e8 scaling factor
      auto block_alignment = std::vector<size_t>{128ul, 4ul};
      size_t expected_x, expected_y, alignment;
-      const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 32 : 16;
+      const size_t block_size_rowwise = 32;
      const size_t block_size_colwise = 32;

      if (t.has_data()) {
...
...
@@ -114,6 +114,7 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
        expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(block_size_rowwise)),
                           alignment) *
                     alignment;
        const auto &expected = std::vector<size_t>{expected_x, expected_y};
        NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
                   "\" has invalid scale_inv shape (expected ", expected, ", got ",
...
...
@@ -126,11 +127,29 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
                     alignment;
        alignment = block_alignment[0];
        expected_y =
            DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(1)), alignment) * alignment;
        const auto &expected = std::vector<size_t>{expected_x, expected_y};
        NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name,
                   "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ",
                   t.columnwise_scale_inv.shape, ")");
      }
+    } else if (t.scaling_mode == NVTE_NVFP4_1D_SCALING) {
+      if (t.has_data()) {
+        const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_first_dim(), 128);
+        const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_last_dim(), 16lu), 4);
+        const auto &expected = std::vector<size_t>{expected_y, expected_x};
+        NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
+                   "\" has invalid scale_inv shape (expected ", expected, ", got ",
+                   t.scale_inv.shape, ")");
+      }
+      if (t.has_columnwise_data()) {
+        const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_last_dim(), 128);
+        const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_first_dim(), 16lu), 4);
+        const auto &expected = std::vector<size_t>{expected_y, expected_x};
+        NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name,
+                   "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ",
+                   t.columnwise_scale_inv.shape, ")");
+      }
    }
  }
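The NVFP4 branch above pads the scale-inverse tensor so it tiles cleanly into the swizzled (128 x 4) base blocks. A Python sketch of the expected shape computation, for reference only:

# Rows are padded to a multiple of 128 and the per-16-element block columns to a multiple of 4.
def divup(a, b):
    return (a + b - 1) // b

def nvfp4_scale_inv_shape(flat_first_dim: int, flat_last_dim: int):
    rows = divup(flat_first_dim, 128) * 128
    cols = divup(divup(flat_last_dim, 16), 4) * 4
    return (rows, cols)

# e.g. a 1024x1024 tensor keeps 1024 rows and needs 1024/16 = 64 scale columns (already 4-aligned)
assert nvfp4_scale_inv_shape(1024, 1024) == (1024, 64)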
...
...
@@ -158,6 +177,26 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
"(expected Float32 or Byte, got "
,
to_string
(
t
.
columnwise_scale_inv
.
dtype
),
")"
);
}
}
else
if
(
is_fp4_dtype
(
type
))
{
// TODO(ksivaman): Fix this to check for amaxes and other details.
// For now only needed for swizzle.
if
(
t
.
has_data
())
{
NVTE_CHECK
(
t
.
scale_inv
.
dptr
!=
nullptr
,
"FP4 scaling factor input "
,
name
,
"_scale_inverse must be allocated"
);
NVTE_CHECK
(
t
.
scale_inv
.
dtype
==
DType
::
kFloat8E4M3
,
"FP4 scaling factor input "
,
name
,
"_scale_inverse has invalid dtype "
"(expected DType::kFloat8E4M3, got "
,
to_string
(
t
.
scale_inv
.
dtype
),
")"
);
}
if
(
t
.
has_columnwise_data
())
{
NVTE_CHECK
(
t
.
columnwise_scale_inv
.
dptr
!=
nullptr
,
"FP4 scaling factor input "
,
name
,
"_columnwise_scale_inverse must be allocated"
);
NVTE_CHECK
(
t
.
columnwise_scale_inv
.
dtype
==
DType
::
kFloat8E4M3
,
"FP8 scaling factor input "
,
name
,
"_columnwise_scale_inverse has invalid dtype "
"(expected DType::kFloat8E4M3, got "
,
to_string
(
t
.
columnwise_scale_inv
.
dtype
),
")"
);
}
}
else
{
NVTE_CHECK
(
t
.
scale
.
dptr
==
nullptr
,
"Scale is not supported for non-FP8 input "
,
name
);
NVTE_CHECK
(
t
.
amax
.
dptr
==
nullptr
,
"Amax is not supported for non-FP8 input "
,
name
);
...
...
@@ -199,10 +238,29 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
"(expected Float32 or Float8E8M0, got "
,
to_string
(
t
.
columnwise_scale_inv
.
dtype
),
")"
);
}
}
else
if
(
is_fp4_dtype
(
type
))
{
// FP4 output needs to have the scale_inv
if
(
t
.
has_data
())
{
NVTE_CHECK
(
t
.
scale_inv
.
dptr
!=
nullptr
,
"FP4 scaling factor output "
,
name
,
"_scale_inverse must be allocated"
);
NVTE_CHECK
(
t
.
scale_inv
.
dtype
==
DType
::
kFloat8E4M3
,
"FP4 scaling factor output "
,
name
,
"_scale_inverse has invalid dtype "
"(expected Float8E4M3, got "
,
to_string
(
t
.
scale_inv
.
dtype
),
")"
);
}
if
(
t
.
has_columnwise_data
())
{
NVTE_CHECK
(
t
.
columnwise_scale_inv
.
dptr
!=
nullptr
,
"FP4 scaling factor output "
,
name
,
"_columnwise_scale_inverse must be allocated"
);
NVTE_CHECK
(
t
.
columnwise_scale_inv
.
dtype
==
DType
::
kFloat8E4M3
,
"FP4 scaling factor output "
,
name
,
"_columnwise_scale_inverse has invalid dtype "
"(expected Float8E4M3, got "
,
to_string
(
t
.
columnwise_scale_inv
.
dtype
),
")"
);
}
}
else
{
NVTE_CHECK
(
t
.
scale
.
dptr
==
nullptr
,
"Scale is not supported for non-FP8 output "
,
name
);
//
Note: amax is supported for non-FP8 output as it can be fused into the computation
//
and later used for quantization with no need to compute it separately
//
Unfused quant with level 2 nvfp4 scaling will produce high precision tensors with amax.
//
NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name);
NVTE_CHECK
(
t
.
scale_inv
.
dptr
==
nullptr
,
"Scale_inv is not supported for non-FP8 output "
,
name
);
NVTE_CHECK
(
t
.
columnwise_scale_inv
.
dptr
==
nullptr
,
"Scale_inv is not supported for non-FP8 input "
,
name
);
...
...
@@ -507,6 +565,9 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
    case kNVTEColumnwiseScaleInv:
      t->columnwise_scale_inv = *param;
      break;
+    case kNVTEColumnwiseAmax:
+      t->columnwise_amax = *param;
+      break;
    default:
      NVTE_ERROR("Unknown tensor parameter!");
  }
...
...
@@ -530,6 +591,8 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p
      return t.scale_inv;
    case kNVTEColumnwiseScaleInv:
      return t.columnwise_scale_inv;
+    case kNVTEColumnwiseAmax:
+      return t.columnwise_amax;
    default:
      NVTE_ERROR("Unknown tensor parameter!");
  }
...
@@ -645,6 +708,15 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
    case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
      std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size);
      break;
+    case kNVTEQuantizationConfigRNGState:
+      std::memcpy(&config_.rng_state, buf, attr_size);
+      break;
+    case kNVTEQuantizationConfigNVFP42DQuantization:
+      std::memcpy(&config_.nvfp4_2d_quantization, buf, attr_size);
+      break;
+    case kNVTEQuantizationConfigStochasticRounding:
+      std::memcpy(&config_.stochastic_rounding, buf, attr_size);
+      break;
    default:
      NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  }
...
...
transformer_engine/common/transpose/cast_transpose.h
View file @ 53fa872c
...
...
@@ -8,6 +8,7 @@
#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
#include "../common.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine::detail {
...
...
@@ -62,6 +63,14 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor
                                         const bool pow_2_scale, const SimpleTensor &noop_tensor,
                                         cudaStream_t stream);

+void quantize_transpose_vector_blockwise_fp4(
+    const SimpleTensor &input, const SimpleTensor &global_amax, SimpleTensor &scale_inv,
+    SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon,
+    const bool return_identity, const bool return_transpose, const bool pow2_scale,
+    const bool swizzled_scale, const bool use_stochastic_rounding,
+    const NVTETensor rng_state_tensor, const bool use_2d_quantization,
+    const SimpleTensor &noop_tensor, cudaStream_t stream);
+
}  // namespace transformer_engine::detail

#endif  // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu
0 → 100644
View file @ 53fa872c
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cfloat>
#include <cuda/barrier>
#include <utility>
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "curanddx.hpp"
namespace transformer_engine {

#if CUDA_VERSION >= 12080

namespace quantize_transpose_nvfp4 {

namespace {

using std::int32_t;
using std::uint32_t;
using std::uint8_t;
using transformer_engine::detail::TypeExtrema;

// Define a cuRANDDx descriptor
// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it defaults to 10.
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used to do some internal checks, e.g.,
// whether shared memory, if needed, is enough for the described problem; usually not applicable.
// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
                     curanddx::SM<800>() + curanddx::Thread());
// clang-format off
/*
Step 1: Load input to shared memory
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 8 times
* What each thread does in each loop:
* 8 elements are read from the input at a time
* 2 elements are written to the shared memory at a time, for a total of 4 times
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 1 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 7 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 8 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 2: Cast and store to output_c
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 4 times
* What each thread does in each loop:
* 2 elements are read from the shared memory at a time, for a total of 8 times
* Every 8 consecutive threads do reduction and calculate the amax of each row
* 16 elements are quantized and write to output_c at a time
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 |
| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 |
| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 1 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 7 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 4 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 3: Transpose, cast and store to output_t
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 2 times
* What each thread does in each loop:
* 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times
* Every 8 consecutive threads do reduction and calculate the amax of each column
* 16 elements are quantized and write to output_c at a time, for a total of 2 times
+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+
| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | |
| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | |
| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | |
| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 |
| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | |
| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | |
| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | |
| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | |
+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+
*/
// clang-format on
constexpr int kThreadsPerWarp = 32;
// for fp4, we use uint8_t to store 2 fp4 numbers
constexpr int kNFP4PerContainer = 2;

// Hyperparameters for performance tuning
constexpr int kTileDim = 128;
// constexpr int kScaleDim = 32;
constexpr int kNVecIn = 8;    // The number of elements each LDG touches
constexpr int kNVecOut = 16;  // The number of elements each STG touches
constexpr int kNVecSMem = 2;  // The number of elements each LDS/STS touches
constexpr int kThreadsPerBlock = 256;  // Thread block size, 8 warps in total

// Auto-calculated constants (do not modify directly)
static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem");
static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem");
constexpr int kSMemRow = kTileDim;
constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1;
constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem;
constexpr int kNumThreadsLoad = kTileDim / kNVecIn;    // 16
constexpr int kNumThreadsStore = kTileDim / kNVecOut;  // 8
// constexpr int kNumThreadsReduce = kScaleDim / kNVecOut;
static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");

// for 2D block scaling, we need to reduce amax in warp
static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = {
    0x01010101, 0x02020202, 0x04040404, 0x08080808,
    0x10101010, 0x20202020, 0x40404040, 0x80808080};
// max for every group_size elements in warp
template <int group_size, int shfl_down_stride>
__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) {
  for (int offset = group_size / 2; offset > 0; offset /= 2) {
    val = max(val, __shfl_down_sync(groupMask, val, offset * shfl_down_stride));
  }
  return val;
}
template <typename ScaleType>
__device__ __forceinline__ ScaleType ComputeDecodeScaleFP4(const float amax,
                                                           const float global_encode_scale) {
  float decode_scale = amax / TypeExtrema<fp4e2m1>::max;
  decode_scale = decode_scale * global_encode_scale;
  decode_scale = fminf(decode_scale, TypeExtrema<float>::max);
  return static_cast<ScaleType>(decode_scale);
}

template <typename ScaleType>
__device__ __forceinline__ float ComputeEncodeScaleFP4(ScaleType decode_scale,
                                                       const float global_decode_scale) {
  return fminf(1.0f / (static_cast<float>(decode_scale) * global_decode_scale),
               TypeExtrema<float>::max);
}

template <typename IType, typename ScaleType>
__device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scale) {
  return static_cast<float>(input) * encode_scale;
}

__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) {
  constexpr float fp8_max = TypeExtrema<fp8e4m3>::max;
  constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;
  float global_encode_scale = fp8_max * fp4_max / global_amax;
  // If scale is infinity, return max value of float32
  global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
  // If global amax is 0 or infinity, return 1
  if (global_amax == 0.f || global_encode_scale == 0.f) {
    return 1.f;
  }
  return global_encode_scale;
}
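The helpers above implement the two-level NVFP4 scaling: a global FP32 encode scale derived from the per-tensor amax, and a per-block E4M3 decode scale derived from the block amax. A Python sketch of the same math (plain floats stand in for the CUDA types; edge-case handling is simplified):

FP4_MAX, FP8_MAX, FLT_MAX = 6.0, 448.0, 3.4028235e38

def global_encode_scale(global_amax: float) -> float:
    if global_amax == 0.0:
        return 1.0
    s = min(FP8_MAX * FP4_MAX / global_amax, FLT_MAX)
    return s if s != 0.0 else 1.0

def block_scales(block_amax: float, g_enc: float):
    # decode scale is stored (as E4M3) next to the block; encode scale is what the
    # high-precision inputs are multiplied by before the FP4 conversion.
    decode = min(block_amax / FP4_MAX * g_enc, FLT_MAX)
    encode = min(g_enc / decode, FLT_MAX) if decode > 0.0 else FLT_MAX
    return decode, encode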
__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) {
  if (rnd_idx == 4) {
    rnd_idx = 0;
    curanddx::uniform_bits dist;
    random_uint4 = dist.generate4(rng);
  }
  // Treat uint4 as an array of 4x uint32_t elements for indexing
  const uint32_t *const rbits_arr = reinterpret_cast<uint32_t *>(&random_uint4);
  const uint32_t rbits = rbits_arr[rnd_idx++];
  return rbits;
}
template <class ScaleType>
__device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, size_t col_idx,
                                                               uint32_t col_length) {
// This function takes in indices from the scale factor matrix and returns an offset in the
// swizzled format. row_idx, col_idx are original indices from the scale factor matrix (unswizzled
// index). col_length is the column length of the scale factor matrix. tile_scales_inv is the
// pointer to the scale factor matrix.
// https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts
// For any scale factor matrix, it's 512B base block. Each base block consists of 128 rows and 4
// columns. Base block is divided into 4 column blocks, each column block has 32 rows and 4
// columns.
// NOTE: There are not a lot of good illustrations about the swizzled scale factor matrix.
// To think in high level, the swizzled scale factor matrix could be composed as:
// unswizzled_scale_factor_matrix = torch.empty((M, N // 16), dtype=torch.uint8)
// cbg_cnt = N // 16 // 4 # Assuming N is divisible by 64
// rb_cnt = M // 128 # Assuming M is divisible by 128
// tmp = unswizzled_scale_factor_matrix.reshape(rb_cnt, 4, 32, cbg_cnt, 4)
// tmp = torch.permute(tmp, (0, 3, 2, 1, 4))
// swizzled_scale_factor_matrix = tmp.reshape((-1, 128, 4))
  constexpr uint32_t kTotalRowsPerBaseBlock = 128;
  constexpr uint32_t kRowsPerBaseBlockCol = 32;
  constexpr uint32_t kColsPerBaseBlockCol = 4;

  const size_t rb = row_idx / kTotalRowsPerBaseBlock;
  const size_t rem = row_idx % kTotalRowsPerBaseBlock;
  const size_t d4 = rem / kRowsPerBaseBlockCol;
  const size_t d3 = rem % kRowsPerBaseBlockCol;
  const size_t cbg = col_idx / kColsPerBaseBlockCol;
  const size_t d5 = col_idx % kColsPerBaseBlockCol;
  const size_t cbg_cnt = DIVUP(col_length, kColsPerBaseBlockCol);

  // row-major offset in the logical shape
  //   (rb_cnt , cbg_cnt , 32 , 4 , 4)
  // Magic number 16 below comes from the fact we have kColsPerBaseBlockCol = 4, and d4 ([0-128] /
  // 32 = [0-4])
  return ((rb * cbg_cnt + cbg) * kRowsPerBaseBlockCol + d3) * 16 + d4 * kColsPerBaseBlockCol + d5;
}
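Since the swizzled layout is easiest to reason about through the reshape/permute description in the comment above, here is a small Python sketch that mirrors the same index arithmetic (illustrative only):

# Mirrors the device function: 512B base blocks of 128 rows x 4 columns, permuted into the
# row-major logical shape (rb_cnt, cbg_cnt, 32, 4, 4).
def swizzled_offset(row_idx: int, col_idx: int, col_length: int) -> int:
    cbg_cnt = (col_length + 3) // 4
    rb, rem = divmod(row_idx, 128)
    d4, d3 = divmod(rem, 32)
    cbg, d5 = divmod(col_idx, 4)
    return ((rb * cbg_cnt + cbg) * 32 + d3) * 16 + d4 * 4 + d5

# Offset 0 stays at 0; the element one base-block column down (row 32) lands 4 entries later
# because of the (..., 32, 4, 4) permutation.
assert swizzled_offset(0, 0, col_length=4) == 0
assert swizzled_offset(32, 0, col_length=4) == 4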
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding(
    const float2 in01, const float2 in23, const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  uint16_t out_4x;
  asm volatile(
      "{\n"
      "cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5;\n\t"
      "}"
      : "=h"(out_4x)
      : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits));
  return *reinterpret_cast<__nv_fp4x4_e2m1 *>(&out_4x);
#else
  NVTE_DEVICE_ERROR(
      "FP4 cvt PTX instructions are architecture-specific. "
      "Try recompiling with sm_XXXa instead of sm_XXX.");
  uint16_t dummy = 0;
  return *reinterpret_cast<__nv_fp4x4_e2m1 *>(&dummy);
#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
                                                                      const float2 in23,
                                                                      const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  // NOTE: rbits unused for rn.
  uint32_t out_4x;  // Only need 16 bit. Using 32 bit container for packing.
  asm volatile(
      "{\n"
      ".reg.b8 f0;\n\t"
      ".reg.b8 f1;\n\t"
      "cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t"
      "cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t"
      "mov.b32 %0, {f0, f1, f0, f1};\n\t"
      "}"
      : "=r"(out_4x)
      : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x));
  return reinterpret_cast<__nv_fp4x4_e2m1 *>(&out_4x)[0];
#else
  NVTE_DEVICE_ERROR(
      "FP4 cvt PTX instructions are architecture-specific. "
      "Try recompiling with sm_XXXa instead of sm_XXX.");
  uint16_t dummy = 0;
  return *reinterpret_cast<__nv_fp4x4_e2m1 *>(&dummy);
#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
template <bool kApplyStochasticRounding>
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23,
                                                              const uint32_t rbits) {
  if constexpr (kApplyStochasticRounding) {
    return cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, rbits);
  } else {
    return cvt_fp32_to_fp4_4x_with_rn(in01, in23, rbits);
  }
}
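For intuition about what the e2m1 conversions above produce, the following Python sketch lists the representable E2M1 magnitudes and a simplified round-to-nearest reference (the PTX instruction additionally saturates to +/-6 and resolves ties to even in hardware; this sketch just picks the nearest value):

E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_e2m1_rn(x: float) -> float:
    sign = -1.0 if x < 0 else 1.0
    mag = min(abs(x), 6.0)  # satfinite
    return sign * min(E2M1_VALUES, key=lambda v: abs(v - mag))

assert quantize_e2m1_rn(5.1) == 6.0
assert quantize_e2m1_rn(-0.24) == 0.0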
template <bool kReturnIdentity, bool kReturnTranspose, bool kIsE8Scaling, bool kAligned,
          typename CType, typename IType, typename OType, typename ScaleType, bool kSwizzledScale,
          bool kApplyStochasticRounding, bool kIs2DBlockScaling>
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
    const IType *const input, const float *global_amax, OType *const output_c,
    OType *const output_t, ScaleType *const tile_scales_inv_c, ScaleType *const tile_scales_inv_t,
    const size_t row_length, const size_t num_rows, const size_t scale_stride_x,
    const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y,
    const size_t kScaleBlockDim, const float epsilon, const size_t *rng_state,
    const float *noop_ptr) {
  constexpr int kNVecContainer = kNVecOut / kNFP4PerContainer;
  using SMemVec = Vec<IType, kNVecSMem>;
  using OVec = Vec<OType, kNVecContainer>;
  union IVec {
    Vec<IType, kNVecIn> input_type;
    Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
  };

  if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
    return;
  }

  const size_t block_idx_x = blockIdx.x;
  const size_t block_idx_y = blockIdx.y;

  const size_t rng_sequence =
      threadIdx.x + block_idx_x * kThreadsPerBlock + block_idx_y * gridDim.x * kThreadsPerBlock;
  const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
  const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
  RNG rng(rng_seed, rng_sequence, rng_offset);
  curanddx::uniform_bits dist;
  uint4 random_uint4 = kApplyStochasticRounding ? dist.generate4(rng) : uint4{0, 0, 0, 0};
  // Index of the random number. It increments each time it is used and resets to 0 when it reaches 4.
  int rnd_idx = 0;

  extern __shared__ char smem_base[];
  SMemVec *smem = reinterpret_cast<SMemVec *>(&smem_base[0]);
  // 2D block scaling is not supported for E8 scaling MXFP4 or for colwise only mode.
  // Instead of static_assert, return early if these invalid modes are detected.
  if constexpr (kIs2DBlockScaling && kIsE8Scaling) {
    return;
  }
  if constexpr (kIs2DBlockScaling && !kReturnIdentity) {
    return;
  }

  // for 128x128 block, 2D block scaling means there will be 8x8 amax values for nvfp4, 4x4 for 2D mxfp4
  // use constexpr to define the size, when not using 2D, use minimal size 1x1
  constexpr int kFP4BlockScalingSize = 16;
  constexpr int k2DBlockAmaxDim = kIs2DBlockScaling ? (kTileDim / kFP4BlockScalingSize) : 1;
  constexpr int kNumRowsPerWarp = kThreadsPerWarp / kNumThreadsStore;  // 4
  constexpr int k2DBlockAmaxReduceDim =
      kIs2DBlockScaling ? (kFP4BlockScalingSize / kNumRowsPerWarp) : 1;
  __shared__ CType amax_smem_red[k2DBlockAmaxDim][k2DBlockAmaxDim][k2DBlockAmaxReduceDim];
  __shared__ CType amax_smem[k2DBlockAmaxDim][k2DBlockAmaxDim];
  // Step 1: Load input to shared memory
  {
    constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad;  // stride in rows of shared memory
    constexpr int num_iterations = kTileDim / r_stride;
    const int c_s = (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem);  // Column in shared memory
    int r_s = threadIdx.x / kNumThreadsLoad;                                  // Row in shared memory
    const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem;              // Column in global memory
    size_t r_g = block_idx_y * kTileDim + r_s;                                // Row in global memory
    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;       // Stride in global memory
    const size_t num_ele =
        (c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g) : 0);  // For not aligned case
    const IType *input_g = &input[r_g * row_length + c_g];  // Input address in global memory

#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      IVec input_vec;
      // Step 1.1: Load from global memory (input) to registers
      if constexpr (kAligned) {
        input_vec.input_type.load_from(input_g);
      } else {
        if (r_g < num_rows) {
          input_vec.input_type.load_from_elts(input_g, 0, num_ele);
        } else {
          input_vec.input_type.clear();
        }
      }
      // Step 1.2: Write to shared memory
#pragma unroll
      for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
        int c = c_s + i;
        int r = r_s;
        smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
      }
      // Step 1.3: Update input address, row index of shared memory, (and row index of global memory
      // for not aligned case)
      input_g += stride_g;
      r_s += r_stride;
      if constexpr (!kAligned) {
        r_g += r_stride;
      }
    }
  }
  __syncthreads();
  const int kNumThreadsReduce = kScaleBlockDim / kNVecOut;

  const float global_encode_scale =
      kIsE8Scaling ? 1.0f : ComputeGlobalEncodeScaleFP4(global_amax[0]);
  const float global_decode_scale = 1.0 / global_encode_scale;
  // Step 2: Cast and store to output_c
  if constexpr (kReturnIdentity) {
    constexpr int r_stride = kThreadsPerBlock / kNumThreadsStore;  // stride in rows of shared memory
    constexpr int num_iterations = kTileDim / r_stride;
    const int c_s = (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem);  // Column in shared memory
    int r_s = threadIdx.x / kNumThreadsStore;                                   // Row in shared memory
    const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem;                // Column in global memory
    size_t r_g = block_idx_y * kTileDim + r_s;                                  // Row in global memory
    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;         // Stride in global memory
    const size_t num_ele = (c_g < row_length ? min(static_cast<size_t>(kNVecOut / kNFP4PerContainer),
                                                   (row_length - c_g) / kNFP4PerContainer)
                                             : 0);  // For not aligned case
    OType *output_g = &output_c[(r_g * row_length + c_g) / kNFP4PerContainer];  // Output address in global memory
    // Each group of kNumThreadsStore threads in a warp processes one row; we need to find the lane
    // id of the first thread to do the reduction.
    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce;
    // This mask represents which threads should do the reduction together.
    const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane;
    const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0;

#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      SMemVec smem_vec[kNVecOut / kNVecSMem];
      // Step 2.1: Load from shared memory to registers
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
        int c = c_s + i;
        int r = r_s;
        smem_vec[i] = smem[r * kSMemCol + c];
      }
      // Step 2.2: Compute local amax
      CType amax = 0;
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
        for (int j = 0; j < kNVecSMem; ++j) {
          __builtin_assume(amax >= 0);
          amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
        }
      }
      // Step 2.3: Reduce amax
      if constexpr (kIsE8Scaling) {
#pragma unroll
        for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) {
          const float other_amax = __shfl_down_sync(mask, amax, delta);
          __builtin_assume(amax >= 0);
          __builtin_assume(other_amax >= 0);
          amax = fmaxf(amax, other_amax);
        }
        amax = __shfl_sync(mask, amax, src_lane);
      }
      // doing shuffle sync for 2D block scaling (not applicable for E8 scaling)
      if constexpr (kIs2DBlockScaling) {
        // first amax shuffle sync in warp, then reduce in smem
        // T0 T8 T16 T24 should do amax reduction together
        constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore;  // 32
        int warp_idx = threadIdx.x / kThreadsPerWarp;                         // 0 ~ 7
        int tid_in_warp_x = threadIdx.x % kNumThreadsStore;
        int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp;
        CType amax_warp_reduced = groupMax<kNumRowsPerWarp, kNumThreadsStore>(
            amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]);
        // now T0 ~ T8 in each warp has the reduced amax values
        int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y;
        if (tid_in_warp_y == 0) {
          amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x]
                       [warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced;
        }
        __syncthreads();
        if (data_row_idx % kFP4BlockScalingSize == 0) {
          CType amax_2d = 0.0;
          for (int i = 0; i < k2DBlockAmaxReduceDim; i++) {
            amax_2d =
                fmaxf(amax_2d, amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]);
          }
          amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d;
        }
        __syncthreads();
        // every thread now knows 2D amax
        amax = amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x];
      }
      // Step 2.4: Compute scale
      ScaleType scale_inv = ComputeDecodeScaleFP4<ScaleType>(amax, global_encode_scale);
      float encode_scale = ComputeEncodeScaleFP4<ScaleType>(scale_inv, global_decode_scale);
      // Step 2.5: Write scale_inv
      bool write_scale_inv = is_src_lane;
      if constexpr (!kAligned) {
        write_scale_inv &= (r_g < num_rows);
        write_scale_inv &= (c_g < row_length);
      }
      if (write_scale_inv) {
        size_t row_idx = block_idx_y * kTileDim + r_s;
        size_t col_idx = block_idx_x * (kNumThreadsStore / kNumThreadsReduce) +
                         (threadIdx.x % kNumThreadsStore) / kNumThreadsReduce;
        if constexpr (kSwizzledScale) {
          size_t offset = scale_factor_swizzled_offset<ScaleType>(row_idx, col_idx,
                                                                  DIVUP(row_length, kScaleBlockDim));
          tile_scales_inv_c[offset] = scale_inv;
        } else {
          tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
        }
      }
      // Step 2.6: Quantize
      OVec output_vec;
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; i += 2) {
        // Pack two elements into __nv_bfloat162
        float2 f2_a;
        float2 f2_b;
        f2_a.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[i].data.elt[0], encode_scale);
        f2_a.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[i].data.elt[1], encode_scale);
        f2_b.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[i + 1].data.elt[0], encode_scale);
        f2_b.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[i + 1].data.elt[1], encode_scale);
        const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0;
        // Convert to __nv_fp4x4_e2m1
        __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x<kApplyStochasticRounding>(f2_a, f2_b, rbits);
        output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t *>(&out_4x)[0];
        output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t *>(&out_4x)[1];
      }
      // Step 2.7: Store output_c
      if constexpr (kAligned) {
        output_vec.store_to(output_g);
      } else {
        if (r_g < num_rows) {
          output_vec.store_to_elts(output_g, 0, num_ele);
        }
      }
      // Step 2.8: Update output address, row index of shared memory (and row index of global memory
      // for not aligned case)
      output_g += stride_g / kNFP4PerContainer;
      r_s += r_stride;
      if constexpr (!kAligned) {
        r_g += r_stride;
      }
    }
  }
  // Step 3: Transpose, cast and store to output_t
  if constexpr (kReturnTranspose) {
    constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore;  // Stride in columns of shared memory
    constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
    const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut;  // Row in shared memory
    int c_s = threadIdx.x / kNumThreadsStore;                     // Column in shared memory
    size_t r_g = block_idx_x * kTileDim + c_s * kNVecSMem;        // Row in global memory
    const size_t c_g = block_idx_y * kTileDim + r_s;              // Column in global memory
    const size_t stride_g = static_cast<size_t>(c_stride) * kNVecSMem * num_rows;  // Stride in global memory
    const size_t num_ele = (c_g < num_rows ? min(static_cast<size_t>(kNVecOut / kNFP4PerContainer),
                                                 (num_rows - c_g) / kNFP4PerContainer)
                                           : 0);  // For not aligned case
    OType *output_g = &output_t[(r_g * num_rows + c_g) / kNFP4PerContainer];  // Output address in global memory
    // Each group of kNumThreadsStore threads in a warp processes one row; we need to find the lane
    // id of the first thread to do the reduction.
    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce;
    // This mask represents which threads should do the reduction together.
    const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane;
    const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0;

#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      SMemVec smem_vec[kNVecOut];
      // Step 3.1: Load from shared memory to registers
#pragma unroll
      for (int i = 0; i < kNVecOut; ++i) {
        int r = r_s + i;
        int c = c_s;
        smem_vec[i] = smem[r * kSMemCol + c];
      }
#pragma unroll
      for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
        // Step 3.2: Compute local amax
        CType amax = 0;
        if constexpr (kIs2DBlockScaling) {
          // TODO(zhongbo): 2D block scaling, directly read from amax_smem
          int warp_idx = threadIdx.x / kThreadsPerWarp;  // 0 ~ 7
          constexpr int kNumColsPerWarp = kThreadsPerWarp / kNumThreadsStore * kNVecSMem;  // 8 elements
          constexpr int kNumWarpsPerBlock = kThreadsPerBlock / kThreadsPerWarp;  // 8 warps per block
          constexpr int kNumColsPerIter = kNumColsPerWarp * kNumWarpsPerBlock;
          int tid_in_warp_x = (threadIdx.x / kNumThreadsStore) % kNumColsPerWarp;
          int tid_in_warp_y = (threadIdx.x % kThreadsPerWarp) % kNumThreadsStore;
          int data_col_idx = iter * kNumColsPerIter + warp_idx * kNumColsPerWarp + tid_in_warp_x;
          amax = amax_smem[tid_in_warp_y][data_col_idx / kFP4BlockScalingSize];
        } else {
#pragma unroll
          for (int i = 0; i < kNVecOut; ++i) {
            amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx]));
          }
        }
        // Step 3.3: Reduce amax
        if constexpr (kIsE8Scaling) {
#pragma unroll
          for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) {
            const float other_amax = __shfl_down_sync(mask, amax, delta);
            __builtin_assume(amax >= 0);
            __builtin_assume(other_amax >= 0);
            amax = fmaxf(amax, other_amax);
          }
          amax = __shfl_sync(mask, amax, src_lane);
        }
        // Step 3.4: Compute scale
        ScaleType scale_inv = ComputeDecodeScaleFP4<ScaleType>(amax, global_encode_scale);
        float encode_scale = ComputeEncodeScaleFP4<ScaleType>(scale_inv, global_decode_scale);
        // Step 3.5: Write scale_inv_t
        bool write_scale_inv = is_src_lane;
        if constexpr (!kAligned) {
          write_scale_inv &= (r_g + smem_idx < row_length);
          write_scale_inv &= (c_g < num_rows);
        }
        if (write_scale_inv) {
          size_t row_idx = block_idx_x * kTileDim + c_s * kNVecSMem + smem_idx;
          size_t col_idx = (block_idx_y * (kNumThreadsStore / kNumThreadsReduce) +
                            (threadIdx.x % kNumThreadsStore) / kNumThreadsReduce);
          if constexpr (kSwizzledScale) {
            size_t offset = scale_factor_swizzled_offset<ScaleType>(row_idx, col_idx,
                                                                    DIVUP(num_rows, kScaleBlockDim));
            tile_scales_inv_t[offset] = scale_inv;
          } else {
            tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
          }
        }
        // Step 3.6: Quantize
        OVec output_vec;
#pragma unroll
        for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 2) {
          // Pack two elements into __nv_bfloat162
          float2 f2_a;
          float2 f2_b;
          f2_a.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * i].data.elt[smem_idx], encode_scale);
          f2_a.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * i + 1].data.elt[smem_idx], encode_scale);
          f2_b.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * (i + 1)].data.elt[smem_idx], encode_scale);
          f2_b.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * (i + 1) + 1].data.elt[smem_idx], encode_scale);
          const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0;
          // Convert to __nv_fp4x4_e2m1
          __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x<kApplyStochasticRounding>(f2_a, f2_b, rbits);
          output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t *>(&out_4x)[0];
          output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t *>(&out_4x)[1];
        }
        // Step 3.7: Store output_t
        if constexpr (kAligned) {
          output_vec.store_to(output_g + smem_idx * num_rows / kNFP4PerContainer);
        } else {
          if (r_g + smem_idx < row_length) {
            output_vec.store_to_elts(output_g + smem_idx * num_rows / kNFP4PerContainer, 0, num_ele);
          }
        }
      }
      // Step 3.8: Update output address, column index of shared memory (and row index of global
      // memory for not aligned case)
      output_g += stride_g / kNFP4PerContainer;
      c_s += c_stride;
      if constexpr (!kAligned) {
        r_g += c_stride * kNVecSMem;
      }
    }
  }
}
}
// namespace
}
// namespace quantize_transpose_nvfp4
#endif // CUDA_VERSION >= 12080
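For readers tracing the vectorized kernel above, the following is a minimal scalar sketch of the same per-block NVFP4 quantization idea (16-element blocks, an E4M3-style per-block decode scale, and a precomputed global encode scale). The helper `to_e4m3_approx` and the omitted FP4 rounding are illustrative simplifications, not the kernel's actual code paths.

// Illustrative host-side reference only: quantize one 16-element block into scaled values
// plus a per-block decode scale, mirroring the two-level NVFP4 scaling used above.
#include <algorithm>
#include <cmath>

constexpr int kBlock = 16;
constexpr float kFP4Max = 6.0f;  // largest magnitude representable in FP4 E2M1

// Hypothetical stand-in for a cast to E4M3: clamp to the E4M3 max; real code also rounds.
inline float to_e4m3_approx(float x) { return std::min(x, 448.0f); }

void quantize_block_nvfp4_ref(const float *in, float *out_scaled, float *out_scale_inv,
                              float global_encode_scale /* S_enc */) {
  float amax = 0.0f;
  for (int i = 0; i < kBlock; ++i) amax = std::max(amax, std::fabs(in[i]));
  // Per-block decode scale (stored as E4M3 in the real kernel): S_dec_b = amax / 6 * S_enc
  const float s_dec_b = to_e4m3_approx(amax / kFP4Max * global_encode_scale);
  *out_scale_inv = s_dec_b;
  // Per-block encode scale applied to the data before the FP4 cast: S_enc / S_dec_b
  const float s_enc_b = (s_dec_b > 0.0f) ? global_encode_scale / s_dec_b : 0.0f;
  for (int i = 0; i < kBlock; ++i) {
    // The real kernel rounds this product to the FP4 E2M1 grid (optionally stochastically).
    out_scaled[i] = in[i] * s_enc_b;
  }
}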
namespace detail {

void quantize_transpose_vector_blockwise_fp4(const SimpleTensor &input, const SimpleTensor &global_amax,
                                             SimpleTensor &scale_inv, SimpleTensor &scale_inv_t,
                                             SimpleTensor &output, SimpleTensor &output_t,
                                             const float epsilon, const bool return_identity,
                                             const bool return_transpose, const bool pow2_scale,
                                             const bool swizzled_scale, const bool use_stochastic_rounding,
                                             const NVTETensor rng_state_tensor,
                                             const bool use_2d_quantization,
                                             const SimpleTensor &noop_tensor, cudaStream_t stream) {
  NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4);
#if CUDA_VERSION >= 12080
  // pow2_scale is intended for MXFP4, which uses E8M0 scaling; raise an error if it is set here.
  NVTE_CHECK(!pow2_scale, "No support for pow2_scale for MXFP4 for now");

  if (!return_identity && !return_transpose) {
    return;
  }
  if (use_2d_quantization && !return_identity) {
    return;
  }

  const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
  size_t num_elements = row_length;
  size_t num_rows = 1;
  for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
    num_rows *= input.shape.at(i);
    num_elements *= input.shape.at(i);
  }

  // Early return if the input tensor is empty
  if (num_elements == 0) {
    return;
  }

  size_t scale_stride_x = 0;
  size_t scale_stride_y = 0;
  if (return_identity) {
    scale_stride_x = 1;
    scale_stride_y = scale_inv.shape[1];
  }
  size_t scale_t_stride_x = 0;
  size_t scale_t_stride_y = 0;
  if (return_transpose) {
    scale_t_stride_x = 1;
    scale_t_stride_y = scale_inv_t.shape[1];
  }

  using namespace transformer_engine::quantize_transpose_nvfp4;

  const size_t num_blocks_x = DIVUP(row_length, static_cast<size_t>(kTileDim));
  const size_t num_blocks_y = DIVUP(num_rows, static_cast<size_t>(kTileDim));

  // noop tensor for cuda graph
  const float *noop_ptr = reinterpret_cast<const float *>(noop_tensor.dptr);

  const size_t *rng_state = nullptr;
  if (rng_state_tensor != nullptr) {
    Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
    NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
               "RNG state should contain 2 64-bit values.");
    NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
               "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
    rng_state = reinterpret_cast<const size_t *>(rng_state_te_tensor.data.dptr);
  }

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.dtype, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(
          output.dtype, 2, OutputType,

          dim3 grid(num_blocks_x, num_blocks_y, 1);
          using ScaleType = fp8e4m3;
          constexpr int kScaleBlockDim = 16;
          constexpr bool kPow2Scale = false;
          const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;

          TRANSFORMER_ENGINE_SWITCH_CONDITION(
              return_identity, kReturnIdentity,
              TRANSFORMER_ENGINE_SWITCH_CONDITION(
                  return_transpose, kReturnTranspose,
                  TRANSFORMER_ENGINE_SWITCH_CONDITION(
                      full_tile, kAligned,
                      TRANSFORMER_ENGINE_SWITCH_CONDITION(
                          swizzled_scale, kSwizzledScale,
                          TRANSFORMER_ENGINE_SWITCH_CONDITION(
                              use_stochastic_rounding, kApplyStochasticRounding,
                              TRANSFORMER_ENGINE_SWITCH_CONDITION(
                                  use_2d_quantization, kIs2DBlockScaling,

                                  size_t smem_bytes = kSMemSize * sizeof(InputType);
                                  auto kernel = block_scaled_1d_cast_transpose_kernel<
                                      kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, float,
                                      InputType, OutputType, ScaleType, kSwizzledScale,
                                      kApplyStochasticRounding, kIs2DBlockScaling>;
                                  if (smem_bytes >= 48 * 1024) {
                                    cudaError_t err = cudaFuncSetAttribute(
                                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
                                    NVTE_CHECK(err == cudaSuccess,
                                               "Failed to set dynamic shared memory size.");
                                  }
                                  kernel<<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
                                      reinterpret_cast<const InputType *>(input.dptr),
                                      reinterpret_cast<const float *>(global_amax.dptr),
                                      reinterpret_cast<OutputType *>(output.dptr),
                                      reinterpret_cast<OutputType *>(output_t.dptr),
                                      reinterpret_cast<ScaleType *>(scale_inv.dptr),
                                      reinterpret_cast<ScaleType *>(scale_inv_t.dptr), row_length,
                                      num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x,
                                      scale_t_stride_y, kScaleBlockDim, epsilon, rng_state,
                                      noop_ptr);)  // kIs2DBlockScaling
                              )  // kApplyStochasticRounding
                          )  // kSwizzledScale
                      )  // kAligned
                  )  // kReturnTranspose
              )  // kReturnIdentity
          )  // OutputType
      )  // InputType

  NVTE_CHECK_CUDA(cudaGetLastError());
#else
  NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif  // CUDA_VERSION >= 12080
}

}  // namespace detail

}  // namespace transformer_engine
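As a side note on the launch bookkeeping above: kernels requesting more than 48 KiB of dynamic shared memory must opt in before launch. A minimal standalone sketch follows; the tile size, element counts, and `dummy_kernel` are assumptions made for illustration and are not the dispatcher's actual machinery.

// Minimal sketch (assumed sizes): tile the problem, opt in to large dynamic shared memory, launch.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void dummy_kernel(const float *in, float *out) { (void)in; (void)out; }

int main() {
  const size_t rows = 1000, cols = 768, tile = 128;   // example shape and tile size
  const dim3 grid((cols + tile - 1) / tile, (rows + tile - 1) / tile, 1);
  const size_t smem_bytes = 64 * 1024;                // example dynamic shared-memory request
  if (smem_bytes >= 48 * 1024) {
    // Same opt-in the dispatcher above performs for its kernel.
    cudaError_t err = cudaFuncSetAttribute(dummy_kernel,
                                           cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    if (err != cudaSuccess) {
      std::printf("shared-memory opt-in failed: %s\n", cudaGetErrorString(err));
      return 1;
    }
  }
  dummy_kernel<<<grid, 256, smem_bytes, nullptr>>>(nullptr, nullptr);
  return cudaGetLastError() == cudaSuccess ? 0 : 1;
}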
transformer_engine/common/util/cast_gated_kernels.cuh
...
...
@@ -603,6 +603,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
       if constexpr (IS_DGATED) {
         const e8m0_t biased_exponent_gate =
             ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
         // const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2;
+        const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
         if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
...
...
@@ -833,6 +834,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
             ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate);
           }
         }
         const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
         const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
         const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
...
...
@@ -956,6 +958,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
   const size_t in_gate_mem = buff_size_aligned_in;
   const size_t out_act_mem = buff_size_aligned_out;
   const size_t out_gate_mem = buff_size_aligned_out;
   const size_t shmem_size =
       grad_mem + (in_act_mem + in_gate_mem) + (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT;
...
...
@@ -1274,7 +1277,7 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
         cast_gated<ParamOP, ActOP>(gated_input, output, stream);
       }
     }
-  } else if (is_mxfp_scaling(output->scaling_mode)) {
+  } else if (is_mxfp8_scaling(output->scaling_mode)) {
     if (use_tma_kernels) {
       cast_mxfp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, stream);
     } else {
...
...
transformer_engine/common/util/cast_kernels.cuh
...
...
@@ -25,6 +25,7 @@
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "math.h"
#include "nvfp4_transpose.cuh"
#include "ptx.cuh"
#include "transformer_engine/transformer_engine.h"
...
...
@@ -110,6 +111,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
   const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
   const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
   const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
   // helps resolving bank conflicts in shmem
   const int thread_lane = threadIdx.x % THREADS_PER_WARP;
   const int bank_group = thread_lane / THREADS_PER_BANK;
...
...
@@ -137,8 +140,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
   // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
   IType *in_sh = reinterpret_cast<IType *>(dshmem);
   IType *act_in_sh = reinterpret_cast<IType *>(dshmem + elt_input_mem);
-  OType *out_rowwise_sh = reinterpret_cast<OType *>(dshmem + in_mem);
-  OType *out_colwise_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
+  OType *out_rowwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem);
+  OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
   IType *cached_act_sh = in_sh;  // in_sh is used as a cache buffer
   constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
...
...
@@ -286,7 +290,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
         const float scaled_out = in * block_scale_inverse;
         const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
-        out_colwise_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
+        out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
       }
     }
...
...
@@ -410,10 +414,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
       // 2. Compute E8M0 scaling factor
       const e8m0_t biased_exponent =
           ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
-      const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
-      const size_t stage_scales_offset_X = scales_offset_X_rowwise;
-      const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
+      const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
+      const int stage_scales_offset_X = scales_offset_X_rowwise;
+      const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
+      if (rowwise_scale_is_within_bounds) {
         scales_rowwise[scale_idx] = biased_exponent;
+      }
       const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
       const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
...
...
@@ -441,7 +447,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
         const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
         const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
         const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
-        out.store_to(&out_rowwise_sh[shmem_offset_rowwise]);
+        out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
       }
     }
...
...
@@ -456,19 +462,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
     // Initiate TMA transfer to copy shared memory to global memory
     if (is_master_thread) {
-      const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
-      const size_t global_offset_X = block_offset_X;
-      const size_t buff_offset = buff * BUFF_DIM;
+      const int global_offset_Y = block_offset_Y + stage_offset_Y;
+      const int global_offset_X = block_offset_X;
+      const int buff_offset = buff * BUFF_DIM;
       if constexpr (ROWWISE_SCALING) {
         ptx::cp_async_bulk_tensor_2d_shared_to_global(
             reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
-            global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_sh[buff_offset]));
+            global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset]));
       }
       if constexpr (COLWISE_SCALING) {
         ptx::cp_async_bulk_tensor_2d_shared_to_global(
             reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
-            global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_sh[buff_offset]));
+            global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset]));
       }
       // Create a "bulk async-group" out of the previous bulk copy operation.
...
...
@@ -489,18 +495,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
     // Added extra 1-element padding per thread_X to reduce bank conflicts
     float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);
-    constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
+    constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
-    const size_t shmem_thread_offset =
+    const int shmem_thread_offset =
         tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1);
 #pragma unroll
     for (int w = 0; w < WAVES; ++w) {
-      const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
-      const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
+      const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
+      const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
 #pragma unroll
       for (int e = 0; e < PACK_SIZE; ++e) {
         const int j = w * PACK_SIZE + e;
-        const size_t shmem_elt_idx = swizzled_group_offset + e;
+        const int shmem_elt_idx = swizzled_group_offset + e;
         partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j];
       }
     }
...
...
@@ -508,15 +514,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
 #pragma unroll
     for (int i = 0; i < THREADS_Y; ++i) {
       // Add extra element offset per MXFP8 scaling block [1x32]
-      const size_t scaling_block = threadIdx.x / SCALE_DIM_X;
+      const int scaling_block = threadIdx.x / SCALE_DIM_X;
       thread_partial_dbias +=
           partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block];
     }
   }
-  const size_t dbias_stride = cols;
-  const size_t dbias_offset_Y = blockIdx.y;
-  const size_t dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
-  const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
+  const int dbias_stride = cols;
+  const int dbias_offset_Y = blockIdx.y;
+  const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
+  const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
   const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols);
   if (!col_out_of_bounds_dbias) {
     dbias_workspace[dbias_idx] = thread_partial_dbias;
...
...
@@ -539,6 +545,528 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#endif  // __HIP_PLATFORM_AMD__
}  // namespace mxfp8_kernel

namespace nvfp4_kernel {

using namespace ptx;

constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 16;

constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = 32;

constexpr size_t PACK_SIZE = 8;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;

// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4;  // 256

// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X;  // 8 = 128 / 16

// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax,
                                                                   const float S_enc) {
  constexpr float rcp_6f = 1.0f / 6.0f;
  // const float S_dec_b = block_amax * rcp_6f;
  // const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
  // return S_dec_b_fp8;
  return static_cast<fp8e4m3>(block_amax * rcp_6f * S_enc);
}

#define DIRECT_SCALING_FACTORS_STORE 1

template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
          typename IType, typename OType, bool COLWISE_SCALING, size_t CHUNK_DIM_Y,
          size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
    cast_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
                      const __grid_constant__ CUtensorMap tensor_map_output_rowwise,
                      const __grid_constant__ CUtensorMap tensor_map_output_colwise,
                      fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0,
                      const float *noop, float *const amax_ptr,
                      const float *const nvfp4_second_stage_scale_ptr, const size_t rows,
                      const size_t cols, const size_t scale_stride_rowwise,
                      const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  constexpr bool ROWWISE_SCALING = true;
  constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
      (!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);

  using IType2 = typename ptx::FPx2<IType>;

  if constexpr (!COMPUTE_ACTIVATIONS) {
    if (noop != nullptr && noop[0] == 1.0f) {
      return;
    }
  }

  constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X;
  constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW;
  constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE;

  static_assert(BUFF_DIM_Y >= SCALE_DIM_Y &&
                "Number of buffer rows must be greater or equal to the size of the columwise "
                "scaling block\0");
  static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y);
  static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE &&
                "Number of buffer rows must be greater or equal to the number of rowwise "
                "processing threads in Y dimension\0");

  constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X;
  constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8;  // Holds 2 elements of 4-bit size
  constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X;
  constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X;

  constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;

  constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE;
  // static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X);  // there should be a sufficient number of
  //                                                   // threads to process one row in a single iteration

  constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;

  const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
  const int block_offset_X = blockIdx.x * CHUNK_DIM_X;
  const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
  const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
  const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
  const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;

  const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
  const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
  const int tid_Y_colwise = 0;
  const int tid_X_colwise = threadIdx.x;

  const int thread_offset_Y_rowwise = tid_Y_rowwise;
  const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
  const int thread_offset_Y_colwise = tid_Y_colwise;
  const int thread_offset_X_colwise = tid_X_colwise;
  // Each thread processes two adjacent elements

  const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
  const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
  const int col_base_colwise = block_offset_X + thread_offset_X_colwise;

  const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);

  const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
  const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
  const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
  const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;

  const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
  const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols;

  // helps resolving bank conflicts in shmem
  const int thread_lane = threadIdx.x % THREADS_PER_WARP;
  const int bank_group = thread_lane / THREADS_PER_BANK;

  constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
  constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
  constexpr size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out_nvfp4 =
      DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out_mxfp8 =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);

  constexpr size_t buff_size_nvfp4_scales =
      CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3);
  constexpr size_t buff_size_mxfp8_scales =
      (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0);

  constexpr size_t in_mem = buff_size_aligned_in;

  constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0);
  constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0);
  constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0);
  constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0);

  extern __shared__ char dynamic_shmem[];
  uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) Does not guarantee the pointer to be aligned!
  uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
                     ~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));

  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
  IType *in_sh = reinterpret_cast<IType *>(dshmem);
  fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
  OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise_data);
  fp8e4m3 *out_rowwise_scales_sh =
      reinterpret_cast<fp8e4m3 *>(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
  e8m0_t *out_colwise_scales_sh = reinterpret_cast<e8m0_t *>(
      dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
  IType *cached_act_sh = in_sh;  // in_sh is used as a cache buffer

  constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;

  const bool is_master_thread = (threadIdx.x == 0);

  // Compute a global encoding/decoding scaling factor for all S_dec_b
  const float S_enc =
      (nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr);

  float thread_amax = 0.0f;

  // Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ alignas(8) uint64_t mbar[STAGES];

  initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, is_master_thread);

  copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
                    &mbar[0], is_master_thread);

#pragma unroll
  for (int stage = 0; stage < STAGES; ++stage) {
    const int buff = stage % BUFFS_NUM;
    const int next_stage = stage + 1;
    const int stage_offset_Y = stage * BUFF_DIM_Y;

    const int buff_offset_in = buff * BUFF_IN_DIM;
    const int buff_offset_out = buff * BUFF_OUT_DIM;

    if (next_stage < STAGES) {
      // Wait for TMA transfer to have finished reading shared memory.
      // I.e. the buffer is ready to be written to
      ptx::cp_async_bulk_wait_group_read<1>();

      const int next_buff = next_stage % BUFFS_NUM;
      const int next_stage_offset_Y = next_stage * BUFF_DIM_Y;
      const int global_offset_Y = block_offset_Y + next_stage_offset_Y;
      const int global_offset_X = block_offset_X;
      const int next_buff_offset = next_buff * BUFF_IN_DIM;

      copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
                        global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
    }

    ptx::fence_proxy_async_shared_cta();

    // Wait for the data to have arrived
    ptx::mbarrier_wait_parity(&mbar[stage], 0);

    float block_amax = 0.0f;
    if constexpr (COLWISE_SCALING) {
      const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise;
      block_amax = 0.0f;
      float in_compute_colwise[SCALE_DIM_Y];
      IType in_colwise_IType[SCALE_DIM_Y];

      // 1. Read/Compute elements. Find MXFP8-block AMAX
      if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
        IType block_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
        for (int i = 0; i < SCALE_DIM_Y; ++i) {
          const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
          in_colwise_IType[i] = in_sh[shmem_offset_colwise];
          block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i]));
        }
        block_amax = static_cast<float>(block_amax_f16);
      } else {
#pragma unroll
        for (int i = 0; i < SCALE_DIM_Y; ++i) {
          const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;

          float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
          if constexpr (COMPUTE_ACTIVATIONS) {
            elt = OP(elt, {});
          }
          // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
          if constexpr (!std::is_same_v<IType, float>) {
            elt = static_cast<float>(static_cast<IType>(elt));
          }
          // Cache computed activations to avoid computing them again in the 2nd pass along another dimension
          if constexpr (IS_CACHED_ACT_OP) {
            cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
          }

          if constexpr (COMPUTE_ACTIVATIONS) {
            const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows);
            const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
            if (!out_of_bounds) {
              block_amax = fmaxf(block_amax, fabsf(elt));
            }
          } else {
            // If no activation, elt is 0 so we can safely do this
            block_amax = fmaxf(block_amax, fabsf(elt));
          }
          in_compute_colwise[i] = elt;
        }
      }

      // 2. Compute E8M0 scaling factor
      const e8m0_t biased_exponent =
          ptx::float_to_e8m0(block_amax * Quantized_Limits<OType>::max_norm_rcp);

      const int global_scales_offset_Y = scales_offset_Y_colwise + stage;
      const int global_scales_offset_X = scales_offset_X_colwise;
      const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
      if (colwise_scale_is_within_bounds) {
        scales_colwise_e8m0[scale_idx] = biased_exponent;
      }

      const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);

      // 3. Scale elements
#pragma unroll
      for (int i = 0; i < SCALE_DIM_Y; ++i) {
        float in;
        if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
          in = static_cast<float>(in_colwise_IType[i]);
        } else {
          in = in_compute_colwise[i];
        }
        const float scaled_out = in * block_scale_inverse;

        const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
        out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
      }
    }

    if constexpr (ROWWISE_SCALING) {
      const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
      for (int it = 0; it < ITERATIONS_ROWWISE; ++it) {
        const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;

        const int shmem_offset_base_rowwise_in =
            buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
        const int shmem_offset_base_rowwise_out =
            buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;

        const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE;

        block_amax = 0.0f;
        float in_compute_rowwise[SCALE_DIM_X];
        Vec<IType, PACK_SIZE> in_cached[WAVES];

        // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
        Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];

        // 1. Read/Compute elements. Find NVFP4-block AMAX
        if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
          IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
            const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;

            // Load elements
            in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
            for (int e = 0; e < PACK_SIZE / 2; ++e) {
              ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
            }
          }
          block_amax =
              static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
        } else if constexpr (IS_CACHED_ACT_OP) {
          // ensures that all writes to cache made in the section above are visible to all threads
          __syncthreads();
          IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
            const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;

            const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
            const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
            const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);

            // Load cached elements
            in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
            // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements)
            // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries
            if (!out_of_bounds) {
              if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
                for (int e = 0; e < PACK_SIZE; ++e) {
                  block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e]));
                }
              } else {
#pragma unroll
                for (int e = 0; e < PACK_SIZE; e += 2) {
                  const IType2 in_cached_2x = {in_cached[w].data.elt[e],
                                               in_cached[w].data.elt[e + 1]};
                  ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
                }
              }
            }
          }
          if constexpr (!std::is_same_v<IType, float>) {
            block_amax =
                static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
          }
        } else {
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
            const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;

            Vec<IType, PACK_SIZE> in;
            Vec<IType, PACK_SIZE> act_in;

            in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
            for (int e = 0; e < PACK_SIZE; ++e) {
              const int j = w * PACK_SIZE + e;
              // Compute element
              float elt = static_cast<float>(in.data.elt[e]);
              if constexpr (COMPUTE_ACTIVATIONS) {
                elt = OP(elt, {});
              }
              // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
              if constexpr (!std::is_same_v<IType, float>) {
                elt = static_cast<float>(static_cast<IType>(elt));
              }
              if constexpr (COMPUTE_ACTIVATIONS) {
                const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
                const bool swizzled_col_out_of_bounds =
                    (block_offset_X + swizzled_thread_idx >= cols);
                const bool out_of_bounds =
                    (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
                if (!out_of_bounds) {
                  block_amax = fmaxf(block_amax, fabsf(elt));
                }
              } else {
                // If no activation, elt is 0 so we can safely do this
                block_amax = fmaxf(block_amax, fabsf(elt));
              }
              in_compute_rowwise[j] = elt;
            }
          }
        }

        // 2. Compute E4M3 scaling factor
        const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc);

#if DIRECT_SCALING_FACTORS_STORE
        // Check boundaries
        if (rowwise_scale_is_within_bounds) {
          const int scales_offset_Y =
              scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE;
          const int scales_offset_X = scales_offset_X_rowwise;
          const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X;
          scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8;
        }
#else
        const int shmem_scales_offset_Y =
            stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise;
        const int shmem_scales_offset_X = tid_X_rowwise;
        const int scale_idx =
            shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X;
        out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8;
#endif
        // Compute "correct" per-block encoding scaling factor
        const float block_scale_inverse =
            __fdiv_rn(S_enc, static_cast<float>(S_dec_b_fp8));  // S_enc_b_fp8

        // 3. Scale elements
#pragma unroll
        for (int w = 0; w < WAVES; ++w) {
          Vec<fp4e2m1x4, PACK_SIZE / 4> out;
          // Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
          for (int e = 0; e < PACK_SIZE / 4; ++e) {
            IType2 in01;
            IType2 in23;
            if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
              in01 = in_IType[w].data.elt[2 * e];
              in23 = in_IType[w].data.elt[2 * e + 1];
            } else if constexpr (IS_CACHED_ACT_OP) {
              in01.x = in_cached[w].data.elt[4 * e];
              in01.y = in_cached[w].data.elt[4 * e + 1];
              in23.x = in_cached[w].data.elt[4 * e + 2];
              in23.y = in_cached[w].data.elt[4 * e + 3];
            } else {
              const int j = w * PACK_SIZE + 4 * e;
              in01.x = in_compute_rowwise[j];
              in01.y = in_compute_rowwise[j + 1];
              in23.x = in_compute_rowwise[j + 2];
              in23.y = in_compute_rowwise[j + 3];
            }
            fp4e2m1x4 &out_quad = reinterpret_cast<fp4e2m1x4 &>(out.data.elt[e]);
            ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse);
          }
          const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
          const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
          const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
          out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
        }
      }
    }

    __builtin_assume(thread_amax >= 0);
    __builtin_assume(block_amax >= 0);
    thread_amax = fmaxf(thread_amax, block_amax);

    // Wait for shared memory writes to be visible to TMA engine.
    ptx::fence_proxy_async_shared_cta();
    __syncthreads();
    // After syncthreads, writes by all threads are visible to TMA engine.

    // Initiate TMA transfer to copy shared memory to global memory
    if (is_master_thread) {
      const int global_offset_Y = block_offset_Y + stage_offset_Y;
      const int global_offset_X = block_offset_X;
      const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM;
      const int buff_offset_mxfp8 = buff * BUFF_IN_DIM;

      if constexpr (ROWWISE_SCALING) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
            global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset_nvfp4]));
      }
      if constexpr (COLWISE_SCALING) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
            global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset_mxfp8]));
      }

      // Create a "bulk async-group" out of the previous bulk copy operation.
      ptx::cp_async_bulk_commit_group();
    }
  }

#if !DIRECT_SCALING_FACTORS_STORE
  // Vectorized store of scaling factors.
  // Each thread stores multiple scaling factors in one store instruction.
  if constexpr (ROWWISE_SCALING) {
    // Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X
    const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x;
    const int scales_offset_X_rowwise = scales_block_offset_X_rowwise;
    const int scale_idx_global =
        scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise;
    const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW;
    if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) &&
        (scales_offset_X_rowwise < (cols / SCALE_DIM_X))) {
      using ScalesVec_t = Vec<fp8e4m3, NVFP4_SCALING_FACTORS_PER_CHUNK_ROW>;
      const ScalesVec_t &scales =
          *reinterpret_cast<ScalesVec_t *>(&out_rowwise_scales_sh[scale_idx_shmem]);
      scales.store_to(&scales_rowwise_e4m3[scale_idx_global]);
    }
  }
#endif

  float chunk_amax = 0.0f;
  if (amax_ptr != nullptr) {
    const int warp_id = threadIdx.x / THREADS_PER_WARP;
    // Reduce the amax over the block
    chunk_amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(thread_amax, warp_id);
  }

  if (is_master_thread && amax_ptr != nullptr) {
    atomicMaxFloat(amax_ptr, chunk_amax);
  }

  destroy_barriers<STAGES>(mbar, is_master_thread);
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}  // namespace nvfp4_kernel

constexpr size_t FP8_CHUNK_DIM_Y = 128;
constexpr size_t FP8_CHUNK_DIM_X = 128;
constexpr size_t FP8_THREADS_PER_CHUNK = 128;
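To make the two-level scaling in `compute_decoding_scaling_factor` and the subsequent `block_scale_inverse` computation above concrete, here is a small arithmetic sketch. The particular numbers are assumed for illustration, and the E4M3 scale is replaced by a plain float stand-in.

// Sketch of the two-level NVFP4 scaling arithmetic (illustrative values only).
#include <cstdio>

int main() {
  const float second_stage_scale = 0.005f;        // assumed *nvfp4_second_stage_scale_ptr value
  const float S_enc = 1.0f / second_stage_scale;  // global encoding scale
  const float block_amax = 3.2f;                  // assumed per-16-element-block amax
  // Per-block decoding factor (stored as E4M3 in the real kernel):
  const float S_dec_b = block_amax * (1.0f / 6.0f) * S_enc;
  // Per-block encoding factor applied to elements before the FP4 cast:
  const float block_scale_inverse = S_enc / S_dec_b;
  std::printf("S_enc=%f  S_dec_b=%f  block_scale_inverse=%f\n", S_enc, S_dec_b, block_scale_inverse);
  // The block amax maps to roughly 6.0, the FP4 E2M1 maximum, before rounding to FP4:
  std::printf("scaled amax = %f\n", block_amax * block_scale_inverse);
  return 0;
}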
...
...
@@ -903,7 +1431,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows,
 }

 template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
-static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) {
+void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) {
   const size_t N = product(input.data.shape);
   const bool isFullTile = (N % ELEMS_PER_BLOCK == 0);
...
...
@@ -1192,6 +1720,141 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
#endif
}

// This kernel supports only two scaling cases:
//     1. r16c0  - Rowwise NVFP4
//     2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &)>
void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) {
  using namespace nvfp4_kernel;
  using namespace ptx;
  checkCuDriverContext(stream);

  NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated.");
  NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
  NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
  NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");

  bool use_colwise_scaling = output->has_columnwise_data();
  if (use_colwise_scaling) {
    NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
               "Columnwise scaling tensor must be allocated");
  }
  CheckNoopTensor(*noop, "cast_noop");

  const size_t rows = input.flat_first_dim();
  const size_t cols = input.flat_last_dim();

  constexpr size_t CHUNK_DIM_Y = 128;
  constexpr size_t CHUNK_DIM_X = 128;
  constexpr size_t THREADS_PER_CHUNK = 128;
  constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;

  const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
  const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
  const dim3 grid(blocks_X, blocks_Y);
  const size_t block_size = THREADS_PER_CHUNK;

  const size_t scale_stride_rowwise = output->scale_inv.shape[1];
  const size_t scale_stride_colwise =
      use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1;

  fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast<fp8e4m3 *>(output->scale_inv.dptr);
  e8m0_t *const scales_colwise_e8m0_ptr =
      use_colwise_scaling ? reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr) : nullptr;
  const ScalingType scaling_type =
      use_colwise_scaling ? ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE;

  float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
  const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
  const float *const nvfp4_second_stage_scale_ptr =
      reinterpret_cast<const float *>(output->scale.dptr);

  // Output data type is only required for the column-wise MXFP8 scaling.
  // It has no effect for the row-wise NVFP4 scaling, but is set to the default E4M3 for the macros to work
  const DType output_data_type =
      use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3;

  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      input.dtype(), IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
          output_data_type, OType,

          alignas(64) CUtensorMap tensor_map_input{};
          alignas(64) CUtensorMap tensor_map_output_rowwise{};
          alignas(64) CUtensorMap tensor_map_output_colwise{};

          create_2D_tensor_map(tensor_map_input, input.data, rows, cols, nvfp4_kernel::BUFF_DIM_Y,
                               BUFF_DIM_X, cols, 0, sizeof(IType) * 8);
          create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols,
                               nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, 4);
          if (use_colwise_scaling) {
            create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols,
                                 nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8);
          }

          constexpr size_t buff_elems = nvfp4_kernel::BUFF_DIM_Y * BUFF_DIM_X;
          constexpr size_t buff_elems_total = nvfp4_kernel::BUFFS_NUM * buff_elems;
          constexpr size_t buff_size_aligned_in =
              DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
          constexpr size_t buff_size_aligned_out_nvfp4 =
              DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
          constexpr size_t buff_size_aligned_out_mxfp8 =
              DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
          constexpr size_t buff_size_nvfp4_scales =
              (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3);
          constexpr size_t buff_size_mxfp8_scales =
              (CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t);

          constexpr size_t in_mem = buff_size_aligned_in;
          const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4;
          const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0;
          const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales;
          const size_t out_colwise_scales_mem = use_colwise_scaling ? buff_size_mxfp8_scales : 0;
          const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem +
                                 out_rowwise_scales_mem + out_colwise_scales_mem +
                                 TMA_SHMEM_ALIGNMENT;
          const size_t dshmem_size = in_mem + out_mem;

          switch (scaling_type) {
            case ScalingType::ROWWISE:
              cudaFuncSetAttribute(
                  cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, false,
                                    CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
                  cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
              cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, false, CHUNK_DIM_Y,
                                CHUNK_DIM_X, THREADS_PER_CHUNK>
                  <<<grid, block_size, dshmem_size, stream>>>(
                      tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
                      scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
                      nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
                      scale_stride_colwise);
              break;
            case ScalingType::BIDIMENSIONAL:
              cudaFuncSetAttribute(
                  cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, true,
                                    CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
                  cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
              cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, true, CHUNK_DIM_Y,
                                CHUNK_DIM_X, THREADS_PER_CHUNK>
                  <<<grid, block_size, dshmem_size, stream>>>(
                      tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
                      scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
                      nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
                      scale_stride_colwise);
              break;
          });  // NOLINT(*)
  );  // NOLINT(*)
}

namespace detail {

using Empty = transformer_engine::Empty;
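For orientation, the dynamic shared-memory request computed by nvfp4_quantize above can be estimated by hand. The sketch below redoes the arithmetic for a BF16 input with MXFP8 column-wise output and the 128x128 chunk configuration; the 128-byte value stands in for TMA_SHMEM_ALIGNMENT and is an assumption here.

// Back-of-the-envelope shared-memory budget for the NVFP4 cast kernel (illustrative only).
#include <cstdio>

int main() {
  const size_t align = 128;  // assumed TMA_SHMEM_ALIGNMENT
  auto divup_to_multiple = [&](size_t x, size_t m) { return ((x + m - 1) / m) * m; };
  const size_t chunk_y = 128, chunk_x = 128, buff_dim_y = 32, buffs_num = 2;
  const size_t buff_elems_total = buffs_num * buff_dim_y * chunk_x;               // 8192 elements
  const size_t in_mem = divup_to_multiple(buff_elems_total * 2, align);           // BF16 input
  const size_t out_nvfp4 = divup_to_multiple((buff_elems_total * 4) / 8, align);  // 4-bit rowwise data
  const size_t out_mxfp8 = divup_to_multiple(buff_elems_total * 1, align);        // 8-bit colwise data
  const size_t nvfp4_scales = (chunk_y * chunk_x) / 16;                           // one E4M3 per 16 elems
  const size_t mxfp8_scales = (chunk_y * chunk_x) / 32;                           // one E8M0 per 32 elems
  const size_t total = in_mem + out_nvfp4 + out_mxfp8 + nvfp4_scales + mxfp8_scales + align;
  std::printf("dynamic shared memory ~= %zu bytes\n", total);
  return 0;
}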
...
...
@@ -1417,13 +2080,26 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
   auto dbias_tensor = convertNVTETensor(dbias);
   auto workspace_tensor = convertNVTETensor(workspace);
 
-  const QuantizationConfig *quant_config_cpp =
-      reinterpret_cast<const QuantizationConfig *>(quant_config);
+  // Quantization config
+  QuantizationConfig quant_config_cpp;
+  if (quant_config != nullptr) {
+    quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
+  }
 
+  // Noop flag
+  Tensor dummy_tensor;
+  Tensor *noop_tensor = &dummy_tensor;
+  if (quant_config_cpp.noop_tensor != nullptr) {
+    noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
+  }
-  // extract noop tensor from quant_config_cpp if it's not null
-  const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr;
-  const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor();
 
+  // Check for unsupported options
+  if (quant_config_cpp.stochastic_rounding) {
+    NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
+               "Stochastic rounding is only supported for NVFP4 quantization.");
+  }
 
   // Dispatch to quantization kernel depending on data format
   switch (output_tensor->scaling_mode) {
     case NVTE_DELAYED_TENSOR_SCALING: {
       if (output_tensor->has_columnwise_data()) {
...
...
@@ -1435,7 +2111,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
         NVTE_CHECK(output_tensor->has_data(),
                    "Quantizing in only the columnwise direction not supported yet!");
         if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) {
-          cast_transpose(*input_tensor, noop_tensor, output_tensor, stream);
+          cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream);
         } else {
           cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, float, ParamOP, OP>(
               *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor,
...
...
@@ -1443,51 +2119,90 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
         }
       } else if (output_tensor->has_data()) {
         fp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(
-            *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor,
+            *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor,
             workspace_tensor, stream);
       }
       break;
     }
     case NVTE_MXFP8_1D_SCALING: {
       mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(
-          *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor,
+          *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor,
           workspace_tensor, stream);
       break;
     }
+    case NVTE_NVFP4_1D_SCALING: {
+      // Check tensors
+      CheckNoopTensor(*noop_tensor, "cast_noop");
+      CheckInputTensor(*input_tensor, "input");
+      CheckOutputTensor(*output_tensor, "output", false);
+
+      // Choose kernel
+      int32_t rows = input_tensor->flat_first_dim();
+      int32_t cols = input_tensor->flat_last_dim();
+      auto dtype = input_tensor->dtype();
+      bool use_optimized_kernel = dtype == DType::kBFloat16 && rows % 32 == 0 && cols % 32 == 0 &&
+                                  output_tensor->has_data();
+
+      // Launch NVFP4 quantize kernel
+      if (use_optimized_kernel) {
+        if (quant_config_cpp.nvfp4_2d_quantization) {
+          nvfp4_quantize_transpose<IS_ACT, ParamOP, OP, true>(
+              *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
+        } else {
+          nvfp4_quantize_transpose<IS_ACT, ParamOP, OP, false>(
+              *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
+        }
+      } else {
+        auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
+                                                                  : output_tensor->columnwise_amax;
+        NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
+                   "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_NVFP4_1D_SCALING for "
+                   "2D quantization");
+        quantize_transpose_vector_blockwise_fp4(
+            /*input=*/input_tensor->data, /*global_amax=*/global_amax,
+            /*scale_inv=*/output_tensor->scale_inv,
+            /*scale_inv_t=*/output_tensor->columnwise_scale_inv,
+            /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
+            /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
+            /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
+            /*swizzled_scale=*/false,
+            /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
+            /*rng_state=*/quant_config_cpp.rng_state,
+            /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
+            /*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
+      }
+      break;
+    }
     case NVTE_BLOCK_SCALING_2D: {
       // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
       NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
                  "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D");
-      bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true;
-      float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
+      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
+      float epsilon = quant_config_cpp.amax_epsilon;
       quantize_transpose_square_blockwise(
           input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
           output_tensor->data, output_tensor->columnwise_data, epsilon,
           /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
-          /*noop_tensor=*/noop_tensor.data, stream);
+          /*noop_tensor=*/noop_tensor->data, stream);
       break;
     }
     case NVTE_BLOCK_SCALING_1D: {
       // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
       NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
                  "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
-      bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false;
-      float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
+      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
+      float epsilon = quant_config_cpp.amax_epsilon;
       FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
       FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
       if (output_tensor->has_data()) {
-        bool rowwise_compact = quant_config_cpp
-                                   ? quant_config_cpp->float8_block_scale_tensor_format ==
-                                         Float8BlockScaleTensorFormat::COMPACT
-                                   : false;
+        bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
+                                Float8BlockScaleTensorFormat::COMPACT);
         rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
                                          : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
       }
       if (output_tensor->has_columnwise_data()) {
-        bool columnwise_compact = quant_config_cpp
-                                      ? quant_config_cpp->float8_block_scale_tensor_format ==
-                                            Float8BlockScaleTensorFormat::COMPACT
-                                      : false;
+        bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
+                                   Float8BlockScaleTensorFormat::COMPACT);
         columnwise_option = columnwise_compact
                                 ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
                                 : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
...
...
@@ -1495,7 +2210,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
       quantize_transpose_vector_blockwise(
           input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
           output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
-          columnwise_option, force_pow_2_scales, noop_tensor.data, stream);
+          columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
       break;
     }
     default:
...
...
transformer_engine/common/util/dequantize_kernels.cuh
...
...
@@ -19,6 +19,8 @@
#include <transformer_engine/cast.h>
#include <cfloat>
#include <cstddef>
#include <cstdint>
#include <limits>
#include "../common.h"
...
...
@@ -28,6 +30,7 @@
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transpose.h"
namespace transformer_engine {
...
...
@@ -337,8 +340,83 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
       );  // NOLINT(*)
   );  // NOLINT(*)
   NVTE_CHECK_CUDA(cudaGetLastError());
-#endif
+#endif  // __HIP_PLATFORM_AMD__
 }

#if CUDA_VERSION >= 12080
template <typename OType>
__global__ void __launch_bounds__(512)
    dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales,
                          const float *const tensor_amax, const size_t N, const size_t M,
                          const size_t scale_stride) {
  const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t x = thread_idx % M;
  const size_t y = thread_idx / M;
  union fp4vec {
    uint64_t vec;
    fp4e2m1x4 small_vec[4];
  };
  using OVec = Vec<OType, 4>;
  const uint64_t *const input_vectorized = reinterpret_cast<const uint64_t *>(input);
  OVec *output_vec = reinterpret_cast<OVec *>(output);
  const size_t my_index = x + y * M;
  const size_t my_scale_index = x + y * scale_stride;
  const size_t my_output_index = (x + y * M) * 4;
  fp4vec value;
  value.vec = input_vectorized[my_index];
  fp8e4m3 scale = scales[my_scale_index];
  float amax = *tensor_amax;
  constexpr float factor_inv = 1.0 / (6.0 * 448.0);
  float final_scale = static_cast<float>(scale) * amax * factor_inv;
#pragma unroll
  for (int i = 0; i < 4; i++) {
    float4 current = static_cast<float4>(value.small_vec[i]);
    OVec out;
    out.data.elt[0] = static_cast<OType>(current.x * final_scale);
    out.data.elt[1] = static_cast<OType>(current.y * final_scale);
    out.data.elt[2] = static_cast<OType>(current.z * final_scale);
    out.data.elt[3] = static_cast<OType>(current.w * final_scale);
    output_vec[my_output_index + i] = out;
  }
}
#endif  // CUDA_VERSION

void fp4_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
#if CUDA_VERSION >= 12080
  CheckInputTensor(input, "input");
  CheckOutputTensor(*output, "output");
  NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type.");
  NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision.");
  NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");

  constexpr int FP4_BLOCK_SIZE = 16;
  const size_t N = input.flat_first_dim();
  const size_t M = input.flat_last_dim();
  NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ",
             FP4_BLOCK_SIZE, ", but got ", input.data.shape, ".");

  const size_t Mread = M / FP4_BLOCK_SIZE;
  const size_t total = N * Mread;
  const size_t threads = 512;
  const size_t blocks = DIVUP(total, threads);

  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      output->data.dtype, OType,
      dequantize_fp4_kernel<<<blocks, threads, 0, stream>>>(
          input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
          reinterpret_cast<fp8e4m3 *>(input.scale_inv.dptr),
          reinterpret_cast<float *>(input.amax.dptr), N, Mread,
          input.scale_inv.shape.back()););  // NOLINT(*)
  NVTE_CHECK_CUDA(cudaGetLastError());
#else
  NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!");
#endif  // CUDA_VERSION >= 12080
}
}  // namespace dequantization

namespace detail {
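The scale applied per FP4 element in dequantize_fp4_kernel above reduces to a small piece of arithmetic; a scalar sketch follows, with the block scale and amax values assumed purely for illustration.

// Scalar reference for the FP4 dequantization scale:
// value * E4M3_block_scale * tensor_amax / (6 * 448).
#include <cstdio>

int main() {
  const float block_scale_e4m3 = 112.0f;  // assumed per-16-element-block scale (stored as E4M3)
  const float tensor_amax = 4.5f;         // assumed global amax recorded at quantization time
  constexpr float factor_inv = 1.0f / (6.0f * 448.0f);  // 6 = FP4 E2M1 max, 448 = E4M3 max
  const float final_scale = block_scale_e4m3 * tensor_amax * factor_inv;
  const float fp4_value = 6.0f;  // largest FP4 E2M1 magnitude
  std::printf("dequantized value = %f\n", fp4_value * final_scale);
  return 0;
}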
...
...
@@ -347,16 +425,24 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream)
   CheckInputTensor(input, "cast_input");
   CheckOutputTensor(*output, "cast_output");

-  if (is_tensor_scaling(input.scaling_mode)) {
+  switch (input.scaling_mode) {
+    case NVTE_DELAYED_TENSOR_SCALING: {
       dequantization::fp8_dequantize(input, output, stream);
-  } else if (is_mxfp_scaling(input.scaling_mode)) {
+      break;
+    }
+    case NVTE_MXFP8_1D_SCALING: {
       if (is_supported_by_CC_100()) {
         dequantization::mxfp8_dequantize(input, output, stream);
       } else {
         NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
       }
-  } else {
-    // TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING
+      break;
+    }
+    case NVTE_NVFP4_1D_SCALING: {
+      dequantization::fp4_dequantize(input, output, stream);
+      break;
+    }
+    default:
+      NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
   }
 }
...
...