OpenDAS / TransformerEngine / Commits
".circleci/unittest/vscode:/vscode.git/clone" did not exist on "8a03087ede7d5b58e6562e4d2ac78dc904303b56"
Commit 9df0c4a3, authored Mar 03, 2026 by wenjh

Merge branch 'nv_main'

Parents: 0d874a4e, f122b07d
Changes: 221
Showing 20 changed files with 5247 additions and 156 deletions (+5247 / -156)
transformer_engine/common/cast/dispatch/quantize.cuh  +84 -0
transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh  +988 -0
transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh  +7 -0
transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh  +804 -0
transformer_engine/common/comm_gemm/comm_gemm.cpp  +35 -27
transformer_engine/common/common.h  +11 -3
transformer_engine/common/fused_attn/fused_attn.cpp  +78 -60
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu  +48 -28
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h  +10 -9
transformer_engine/common/fused_attn/fused_attn_fp8.cu  +3 -1
transformer_engine/common/fused_attn/utils.cu  +3 -1
transformer_engine/common/fused_attn/utils.h  +7 -5
transformer_engine/common/gemm/config.cpp  +103 -0
transformer_engine/common/gemm/config.h  +19 -0
transformer_engine/common/gemm/cublaslt_gemm.cu  +20 -22
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu  +645 -0
transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu  +586 -0
transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu  +1513 -0
transformer_engine/common/include/transformer_engine/activation.h  +137 -0
transformer_engine/common/include/transformer_engine/cast.h  +146 -0
transformer_engine/common/cast/dispatch/quantize.cuh
...
...
@@ -18,6 +18,7 @@
#include "../../util/vectorized_pointwise.h"
#include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh"
#include "../mxfp8/group_quantize_mxfp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh"
#include "../nvfp4/group_quantize_transpose_nvfp4.cuh"
#include "../nvfp4/quantize_nvfp4.cuh"
...
...
@@ -381,6 +382,89 @@ void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs,
}
}
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                               const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  using namespace detail;
  NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output);

  const NVTEGroupedTensor activation = nullptr;
  NVTETensor dbias = nullptr;
  NVTETensor workspace = nullptr;

  const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input);
  GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output);
  const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation);
  Tensor *dbias_tensor = convertNVTETensor(dbias);
  Tensor *workspace_tensor = convertNVTETensor(workspace);

  // Quantization config
  QuantizationConfig quant_config_cpp;
  if (quant_config != nullptr) {
    quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
  }

  // Noop flag
  Tensor dummy_tensor;
  Tensor *noop_tensor = &dummy_tensor;
  if (quant_config_cpp.noop_tensor != nullptr) {
    noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
  }

  // Dispatch to quantization kernel depending on data format
  switch (scaling_mode) {
    case NVTE_MXFP8_1D_SCALING: {
      mxfp8::group_quantize</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, ParamOP, OP>(
          input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor,
          workspace_tensor, stream);
      break;
    }
    default:
      NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
  }
}

template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                               NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace,
                               const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  using namespace detail;
  NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output);

  const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad);
  const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input);
  GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output);
  Tensor *dbias_tensor = convertNVTETensor(dbias);
  Tensor *workspace_tensor = convertNVTETensor(workspace);

  // Quantization config
  QuantizationConfig quant_config_cpp;
  if (quant_config != nullptr) {
    quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
  }

  // Noop flag
  Tensor dummy_tensor;
  Tensor *noop_tensor = &dummy_tensor;
  if (quant_config_cpp.noop_tensor != nullptr) {
    noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
  }

  // Dispatch to quantization kernel depending on data format
  switch (scaling_mode) {
    case NVTE_MXFP8_1D_SCALING: {
      mxfp8::group_quantize<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, ParamOP, OP>(
          grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
          stream);
      break;
    }
    default:
      NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
  }
}

}  // namespace dispatch
}  // namespace transformer_engine
...
...
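Reviewer note: the new grouped forward helper above is a template, so a caller must pin down the activation parameters. Below is a minimal sketch of one possible instantiation for a plain cast (no fused activation). The wrapper name is hypothetical and not part of this commit; it assumes the internal dispatch/quantize.cuh header is visible and mirrors how the NVFP4 path in this commit declares a null OP for the no-activation case.

// Illustration only: instantiate group_quantize_fwd_helper for a plain grouped MXFP8 cast.
// Only group_quantize_fwd_helper itself comes from the diff; the wrapper is hypothetical.
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
// Assumes the internal header that defines transformer_engine::dispatch::group_quantize_fwd_helper
// (transformer_engine/common/cast/dispatch/quantize.cuh) is on the include path.

void example_group_quantize_mxfp8(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                                  const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  using namespace transformer_engine;
  // IS_ACT = false with a null OP means no activation is fused into the quantization.
  dispatch::group_quantize_fwd_helper<false, Empty, nullptr>(input, output, quant_config, stream);
}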
transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh
0 → 100644
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file group_quantize_mxfp8.cuh
* \brief CUDA kernels to quantize grouped tensors to MXFP8.
*/
#ifndef TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_
#define TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "../core/common.cuh"
#include "swizzle.cuh"
namespace transformer_engine {
namespace dispatch {
namespace mxfp8 {
namespace group_quantize_kernel {

constexpr int MAX_SUPPORTED_TENSOR_DESCRIPTORS = 64;

__device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS];
__device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS];
__device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS];
__device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS];

enum ShapeRepresentation {
  SAME_BOTH_DIMS = 0,
  VARYING_FIRST_DIM = 1,
  VARYING_LAST_DIM = 2,
  VARYING_BOTH_DIMS = 3
};

constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 32;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t PACK_SIZE = 4;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 128;
constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X;
constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X;
constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X;
constexpr size_t BUFF_DIM_Y = THREADS_Y;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X;
static_assert(BUFF_DIM_Y == 32);

constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
static_assert(STAGES >= 1);

// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1;  // 128
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X;  // 4 = 128 / 32

__device__ __forceinline__ size_t get_current_tensor_id(
    const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset,
    const size_t first_logical_dim, const size_t last_logical_dim,
    const int64_t *const __restrict__ offsets_ptr) {
  if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) {
    const size_t current_row = current_offset / last_logical_dim;
    const size_t rows_per_tensor = first_logical_dim / num_tensors;
    return current_row / rows_per_tensor;
  } else {
    size_t low = 1;
    size_t hi = num_tensors;  // [low, hi]
    while (low < hi) {
      const size_t mid = low + (hi - low) / 2;
      const size_t mid_offset = static_cast<size_t>(offsets_ptr[mid]);
      if (mid_offset <= current_offset) {
        low = mid + 1;
      } else {
        hi = mid;
      }
    }
    return low - 1;
  }
}
__device__ __forceinline__ size_t get_tensor_rows_num(
    const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t first_logical_dim,
    const int64_t *const __restrict__ first_dims_ptr, const size_t num_tensors) {
  size_t rows_num = 0;
  switch (shape_rep) {
    case ShapeRepresentation::SAME_BOTH_DIMS:
    case ShapeRepresentation::VARYING_LAST_DIM:
      rows_num = first_logical_dim;
      break;
    case ShapeRepresentation::VARYING_FIRST_DIM:
    case ShapeRepresentation::VARYING_BOTH_DIMS:
      rows_num = static_cast<size_t>(first_dims_ptr[tensor_id]);
      break;
  }
  return rows_num;
}

__device__ __forceinline__ size_t get_tensor_cols_num(
    const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t last_logical_dim,
    const int64_t *const __restrict__ last_dims_ptr) {
  size_t cols_num = 0;
  switch (shape_rep) {
    case ShapeRepresentation::SAME_BOTH_DIMS:
    case ShapeRepresentation::VARYING_FIRST_DIM:
      cols_num = last_logical_dim;
      break;
    case ShapeRepresentation::VARYING_LAST_DIM:
    case ShapeRepresentation::VARYING_BOTH_DIMS:
      cols_num = static_cast<size_t>(last_dims_ptr[tensor_id]);
      break;
  }
  return cols_num;
}
// Copies the base tensor map to shmem, modifies the copy, stores the modified tensor map at index
__device__ __forceinline__ void modify_base_tensor_map(
    const CUtensorMap base_tensor_map, CUtensorMap *global_tensor_map,
    const uintptr_t global_data_ptr, const size_t global_dim_Y, const size_t global_dim_X,
    const size_t data_type_size_bytes) {
  __shared__ CUtensorMap shared_tensor_map;
  shared_tensor_map = base_tensor_map;  // Copy the base tensor map into shmem
  constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
  if constexpr (is_blackwell) {
    const size_t global_stride_bytes = global_dim_X * data_type_size_bytes;
    if (global_stride_bytes % TMA_GMEM_ALIGNMENT != 0) {
      NVTE_DEVICE_ERROR("Shape not supported, as data stride must be 16B aligned.");
    }
    if (global_data_ptr % TMA_GMEM_ALIGNMENT != 0) {
      NVTE_DEVICE_ERROR("Tensor data pointer must be 16B aligned");
    }
    asm volatile(
        "{\n\t"
        ".reg.b64 tensor_map_ptr;\n\t"
        "mov.b64 tensor_map_ptr, %0;\n\t"
        "tensormap.replace.tile.global_address.b1024.b64 [tensor_map_ptr], %1;\n\t"
        "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 1, %2;\n\t"  // DIM Y
        "tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 0, %3;\n\t"  // DIM X
        "tensormap.replace.tile.global_stride.b1024.b64 [tensor_map_ptr], 0, %4;\n"
        "}\n" ::"l"(reinterpret_cast<uintptr_t>(&shared_tensor_map)),
        "l"(global_data_ptr), "r"(static_cast<uint32_t>(global_dim_Y)),
        "r"(static_cast<uint32_t>(global_dim_X)), "l"(static_cast<uint64_t>(global_stride_bytes))
        : "memory");
    *global_tensor_map = shared_tensor_map;
  } else {
    NVTE_DEVICE_ERROR(
        "tensormap.replace is architecture-specific. "
        "Try recompiling with sm_XXXa instead of sm_XXX.");
  }
}
template <typename IType, typename OType>
__global__ void update_tma_descriptors(
    const __grid_constant__ CUtensorMap base_tensor_map_input,
    const __grid_constant__ CUtensorMap base_tensor_map_act_input,
    const __grid_constant__ CUtensorMap base_tensor_map_output_rowwise,
    const __grid_constant__ CUtensorMap base_tensor_map_output_colwise,
    const IType *const __restrict__ input_data_ptr,
    const IType *const __restrict__ act_input_data_ptr,
    const OType *const __restrict__ output_rowwise_data_ptr,
    const OType *const __restrict__ output_colwise_data_ptr, const ShapeRepresentation shape_rep,
    const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim,
    const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr,
    const int64_t *const __restrict__ last_dims_ptr, const bool rowwise, const bool colwise,
    const bool compute_dactivations) {
  const bool leading_thread = (threadIdx.x == 0);
  const size_t tensor_id = blockIdx.x;

  const size_t rows =
      get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
  const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr);
  const size_t offset_elts = offsets_ptr[tensor_id];

  if (leading_thread && (tensor_id < num_tensors)) {
    {
      const uintptr_t global_data_ptr = reinterpret_cast<uintptr_t>(input_data_ptr + offset_elts);
      modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id],
                             global_data_ptr, rows, cols, sizeof(IType));
    }
    if (compute_dactivations) {
      const uintptr_t global_data_ptr =
          reinterpret_cast<uintptr_t>(act_input_data_ptr + offset_elts);
      modify_base_tensor_map(base_tensor_map_act_input, &g_tensor_maps_act_input[tensor_id],
                             global_data_ptr, rows, cols, sizeof(IType));
    }
    if (rowwise) {
      const uintptr_t global_data_ptr =
          reinterpret_cast<uintptr_t>(output_rowwise_data_ptr + offset_elts);
      modify_base_tensor_map(base_tensor_map_output_rowwise,
                             &g_tensor_maps_output_rowwise[tensor_id], global_data_ptr, rows, cols,
                             sizeof(OType));
    }
    if (colwise) {
      const uintptr_t global_data_ptr =
          reinterpret_cast<uintptr_t>(output_colwise_data_ptr + offset_elts);
      modify_base_tensor_map(base_tensor_map_output_colwise,
                             &g_tensor_maps_output_colwise[tensor_id], global_data_ptr, rows, cols,
                             sizeof(OType));
    }
  }
}
__device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tensor_map) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" ::"l"(tensor_map));
#else
  NVTE_DEVICE_ERROR("fence_acquire_tensormap is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
          float (*OP)(float, const ParamOP &), typename IType, typename OType,
          bool ROWWISE_SCALING, bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES>
__global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel(
    const __grid_constant__ CUtensorMap tensor_map_input_static,
    const __grid_constant__ CUtensorMap tensor_map_act_input_static,
    const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static,
    const __grid_constant__ CUtensorMap tensor_map_output_colwise_static,
    const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t first_logical_dim,
    const size_t last_logical_dim, const int64_t *const __restrict__ offsets_ptr,
    const int64_t *const __restrict__ first_dims_ptr, const int64_t *const __restrict__ last_dims_ptr,
    e8m0_t *const __restrict__ scales_rowwise_ptr, e8m0_t *const __restrict__ scales_colwise_ptr,
    const float *__restrict__ noop, float *const __restrict__ dbias_workspace,
    float *const __restrict__ amax_ptr) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT;
  constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS;

  using IType2 = typename ptx::FPx2<IType>;
  using OType2 = typename ptx::FPx2<OType>;
  using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx;

  if constexpr (NO_ACTIVATIONS) {
    if (noop != nullptr && noop[0] == 1.0f) {
      return;
    }
  }
  constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;

  const size_t block_global_offset = blockIdx.x * ELTS_PER_CHUNK;
  const size_t tensor_id = get_current_tensor_id(shape_rep, num_tensors, block_global_offset,
                                                 first_logical_dim, last_logical_dim, offsets_ptr);
  const size_t rows =
      get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
  const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr);

  const size_t scale_stride_rowwise = DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast<size_t>(32)), 4);
  const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128);

  const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM);
  // grouped tensor can be treated as continuous tensor for MXFP8
  const size_t tensor_base = is_single_tensor ? 0 : static_cast<size_t>(offsets_ptr[tensor_id]);

  const CUtensorMap &tensor_map_input =
      is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id];
  const CUtensorMap &tensor_map_act_input =
      is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id];
  const CUtensorMap &tensor_map_output_rowwise =
      is_single_tensor ? tensor_map_output_rowwise_static : g_tensor_maps_output_rowwise[tensor_id];
  const CUtensorMap &tensor_map_output_colwise =
      is_single_tensor ? tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id];

  const bool leading_thread = (threadIdx.x == 0);

  if (leading_thread && (!is_single_tensor)) {
    fence_acquire_tensormap(&tensor_map_input);
    if constexpr (COMPUTE_ACTIVATIONS) {
      fence_acquire_tensormap(&tensor_map_act_input);
    }
    if constexpr (ROWWISE_SCALING) {
      fence_acquire_tensormap(&tensor_map_output_rowwise);
    }
    if constexpr (COLWISE_SCALING) {
      fence_acquire_tensormap(&tensor_map_output_colwise);
    }
  }

  const size_t blocks_X_num_in_current_tensor = DIVUP(cols, static_cast<size_t>(128));
  const size_t block_id_in_current_tensor =
      is_single_tensor ? blockIdx.x : (blockIdx.x - tensor_base / ELTS_PER_CHUNK);
  const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor;
  const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor;

  const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y;
  const size_t block_offset_X = block_id_X * CHUNK_DIM_X;

  e8m0_t *const scales_rowwise = scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X);
  e8m0_t *const scales_colwise = scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y);

  const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y;
  const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X;
  const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y;
  const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X;

  const size_t tid_Y_rowwise = threadIdx.x / THREADS_X;
  const size_t tid_X_rowwise = threadIdx.x % THREADS_X;
  const size_t tid_Y_colwise = 0;
  const size_t tid_X_colwise = threadIdx.x;

  const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
  const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;

  const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
  const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
  const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
  const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;

  // helps resolving bank conflicts in shmem
  const int thread_lane = threadIdx.x % THREADS_PER_WARP;
  const int bank_group = thread_lane / THREADS_PER_BANK;

  constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
  constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
  constexpr size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);

  constexpr size_t elt_input_mem = buff_size_aligned_in;
  constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
  constexpr size_t in_mem = elt_input_mem + act_input_mem;

  constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0);

  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
  extern __shared__ unsigned char dynamic_shmem[];
  unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem);

  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
  IType *in_sh = reinterpret_cast<IType *>(dshmem);
  IType *act_in_sh = reinterpret_cast<IType *>(dshmem + elt_input_mem);
  OType *out_rowwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem);
  OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
  IType *cached_act_sh = in_sh;  // in_sh is used as a cache buffer

  constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;

  float partial_dbias_colwise = 0.0f;
  float thread_dbias_rowwise[SCALE_DIM_X];
  if constexpr (IS_DBIAS) {
#pragma unroll
    for (int j = 0; j < SCALE_DIM_X; ++j) {
      thread_dbias_rowwise[j] = 0.0f;
    }
  }

  float block_amax = 0.0f;

  // Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ alignas(8) uint64_t mbar[STAGES];

  initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, leading_thread);

  int parity = 0;
  if constexpr (IS_DACT) {
    copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y,
                        &act_in_sh[0], &tensor_map_act_input, block_offset_X, block_offset_Y,
                        shmem_buff_size, &mbar[0], leading_thread);
  } else {
    copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y,
                      shmem_buff_size, &mbar[0], leading_thread);
  }

#pragma unroll
  for (int stage = 0; stage < STAGES; ++stage) {
    const size_t buff = stage % BUFFS_NUM;
    const size_t next_stage = stage + 1;
    const size_t stage_offset_Y = stage * BUFF_DIM_Y;

    if (next_stage < STAGES) {
      // Wait for TMA transfer to have finished reading shared memory.
      // I.e. the buffer is ready to be written to
      ptx::cp_async_bulk_wait_group_read<1>();

      const size_t next_buff = next_stage % BUFFS_NUM;
      const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
      const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
      const size_t global_offset_X = block_offset_X;
      const size_t next_buff_offset = next_buff * BUFF_DIM;
      if constexpr (IS_DACT) {
        copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
                            global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input,
                            global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage],
                            leading_thread);
      } else {
        copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
                          global_offset_Y, shmem_buff_size, &mbar[next_stage], leading_thread);
      }
    }

    ptx::fence_proxy_async_shared_cta();

    // Wait for the data to have arrived
    ptx::mbarrier_wait_parity(&mbar[stage], parity);

    float thread_amax = 0.0f;

    if constexpr (COLWISE_SCALING) {
      const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise;
      thread_amax = 0.0f;
      float in_compute_colwise[BUFF_DIM_Y];
      IType in_colwise_IType[BUFF_DIM_Y];

      // 1. Read/Compute elements. Find MXFP8-block AMAX
      if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
        IType thread_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
        for (int i = 0; i < BUFF_DIM_Y; ++i) {
          const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
          in_colwise_IType[i] = in_sh[shmem_offset_colwise];
          thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i]));
        }
        thread_amax = static_cast<float>(thread_amax_f16);
      } else {
#pragma unroll
        for (int i = 0; i < BUFF_DIM_Y; ++i) {
          const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;

          float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
          if constexpr (IS_ACT) {
            elt = OP(elt, {});
          }
          if constexpr (IS_DACT) {
            float act_in_elt = static_cast<float>(act_in_sh[shmem_offset_colwise]);
            elt *= OP(act_in_elt, {});
          }
          if constexpr (IS_DBIAS) {
            partial_dbias_colwise += elt;
          }
          // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
          if constexpr (!std::is_same_v<IType, float>) {
            elt = static_cast<float>(static_cast<IType>(elt));
          }
          // Cache computed activations to avoid computing them again in the 2nd pass along another dimension
          if constexpr (IS_CACHED_ACT_OP) {
            cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
          }
          thread_amax = fmaxf(thread_amax, fabsf(elt));
          in_compute_colwise[i] = elt;
        }
      }

      // 2. Compute E8M0 scaling factor
      const e8m0_t biased_exponent =
          ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);

      const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
      const size_t global_scales_offset_X = scales_offset_X_colwise;
      size_t scale_idx = 0;
      if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
        scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
                                            DIVUP(rows, static_cast<size_t>(128)));
      } else {
        scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
      }
      scales_colwise[scale_idx] = biased_exponent;

      const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
      const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};

      // 3. Scale elements
#pragma unroll
      for (int i = 0; i < SCALE_DIM_Y; ++i) {
        float in;
        if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
          in = static_cast<float>(in_colwise_IType[i]);
        } else {
          in = in_compute_colwise[i];
        }
        const float scaled_out = in * block_scale_inverse;

        const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
        out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
      }
    }
    if constexpr (ROWWISE_SCALING) {
      const size_t shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
      thread_amax = 0.0f;
      float in_compute_rowwise[SCALE_DIM_X];
      Vec<IType, PACK_SIZE> in_cached[WAVES];

      // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY
      Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];

      // 1. Read/Compute elements. Find MXFP8-block AMAX
      if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
        IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
        for (int w = 0; w < WAVES; ++w) {
          const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
          const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
          const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
          // Load elements
          in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
          for (int e = 0; e < PACK_SIZE / 2; ++e) {
            ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
          }
        }
        thread_amax = static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
      } else if constexpr (IS_CACHED_ACT_OP) {
        // ensures that all writes to cache made in the section above are visible to all threads
        __syncthreads();
        IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
        for (int w = 0; w < WAVES; ++w) {
          const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
          const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
          const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
          // Load cached elements
          in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
          // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements)
          // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries
          if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
            for (int e = 0; e < PACK_SIZE; ++e) {
              thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e]));
            }
          } else {
#pragma unroll
            for (int e = 0; e < PACK_SIZE; e += 2) {
              const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]};
              ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
            }
          }
        }
        if constexpr (!std::is_same_v<IType, float>) {
          thread_amax = static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
        }
      } else {
#pragma unroll
        for (int w = 0; w < WAVES; ++w) {
          const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
          const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
          const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;

          Vec<IType, PACK_SIZE> in;
          Vec<IType, PACK_SIZE> act_in;

          in.load_from(&in_sh[shmem_offset_rowwise]);
          if constexpr (IS_DACT) {
            act_in.load_from(&act_in_sh[shmem_offset_rowwise]);
          }
#pragma unroll
          for (int e = 0; e < PACK_SIZE; ++e) {
            const int j = w * PACK_SIZE + e;
            // Compute element
            float elt = static_cast<float>(in.data.elt[e]);
            if constexpr (IS_ACT) {
              elt = OP(elt, {});
            }
            if constexpr (IS_DACT) {
              float act_in_elt = static_cast<float>(act_in.data.elt[e]);
              elt *= OP(act_in_elt, {});
            }
            // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again
            if constexpr (IS_DBIAS && (!COLWISE_SCALING)) {
              thread_dbias_rowwise[j] += elt;
            }
            // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
            if constexpr (!std::is_same_v<IType, float>) {
              elt = static_cast<float>(static_cast<IType>(elt));
            }
            thread_amax = fmaxf(thread_amax, fabsf(elt));
            in_compute_rowwise[j] = elt;
          }
        }
      }

      // 2. Compute E8M0 scaling factor
      const e8m0_t biased_exponent =
          ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
      const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
      const int stage_scales_offset_X = scales_offset_X_rowwise;
      size_t scale_idx = 0;
      if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
        scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X,
                                            DIVUP(cols, static_cast<size_t>(128)));
      } else {
        scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
      }
      scales_rowwise[scale_idx] = biased_exponent;

      const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
      const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};

      // 3. Scale elements
#pragma unroll
      for (int w = 0; w < WAVES; ++w) {
        Vec<OType2, PACK_SIZE / 2> out;
#pragma unroll
        for (int e = 0; e < PACK_SIZE / 2; ++e) {
          IType2 in;
          OType2 &out_pair = reinterpret_cast<OType2 &>(out.data.elt[e]);
          if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
            in = in_IType[w].data.elt[e];
          } else if constexpr (IS_CACHED_ACT_OP) {
            in.x = in_cached[w].data.elt[2 * e];
            in.y = in_cached[w].data.elt[2 * e + 1];
          } else {
            const int j = w * PACK_SIZE + 2 * e;
            in.x = in_compute_rowwise[j];
            in.y = in_compute_rowwise[j + 1];
          }
          ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x);
        }
        const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
        const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
        const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
        out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
      }
    }

    __builtin_assume(block_amax >= 0);
    __builtin_assume(thread_amax >= 0);
    block_amax = fmaxf(block_amax, thread_amax);

    // Wait for shared memory writes to be visible to TMA engine.
    ptx::fence_proxy_async_shared_cta();
    __syncthreads();
    // After syncthreads, writes by all threads are visible to TMA engine.

    // Initiate TMA transfer to copy shared memory to global memory
    if (leading_thread) {
      const int global_offset_Y = block_offset_Y + stage_offset_Y;
      const int global_offset_X = block_offset_X;
      const int buff_offset = buff * BUFF_DIM;

      if constexpr (ROWWISE_SCALING) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
            global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset]));
      }
      if constexpr (COLWISE_SCALING) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
            global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset]));
      }

      // Create a "bulk async-group" out of the previous bulk copy operation.
      ptx::cp_async_bulk_commit_group();
    }
  }
  parity ^= 1;

  if constexpr (IS_DBIAS) {
    if (is_single_tensor) {
      float thread_partial_dbias = 0.0f;
      if constexpr (COLWISE_SCALING) {
        thread_partial_dbias = partial_dbias_colwise;
      } else {
        // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH]
        // HEIGHT = THREADS_Y
        // WIDTH = THREADS_X * (SCALE_DIM_X + 1)
        // Added extra 1-element padding per thread_X to reduce bank conflicts
        float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);

        constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);

        const int shmem_thread_offset =
            tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1);
#pragma unroll
        for (int w = 0; w < WAVES; ++w) {
          const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
          const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
#pragma unroll
          for (int e = 0; e < PACK_SIZE; ++e) {
            const int j = w * PACK_SIZE + e;
            const int shmem_elt_idx = swizzled_group_offset + e;
            partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j];
          }
        }
        __syncthreads();
#pragma unroll
        for (int i = 0; i < THREADS_Y; ++i) {
          // Add extra element offset per MXFP8 scaling block [1x32]
          const int scaling_block = threadIdx.x / SCALE_DIM_X;
          thread_partial_dbias +=
              partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block];
        }
      }
      const int dbias_stride = cols;
      const int dbias_offset_Y = block_id_Y;
      const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x;
      const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
      const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols);
      if (!col_out_of_bounds_dbias) {
        dbias_workspace[dbias_idx] = thread_partial_dbias;
      }
    }
  }

  if (amax_ptr != nullptr) {
    const int warp_id = threadIdx.x / THREADS_PER_WARP;
    // Reduce the amax over the block
    block_amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(block_amax, warp_id);
  }

  if (leading_thread && amax_ptr != nullptr) {
    atomicMaxFloat(amax_ptr, block_amax);
  }

  destroy_barriers<STAGES>(mbar, leading_thread);
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}  // namespace group_quantize_kernel
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
          float (*OP)(float, const ParamOP &)>
void group_quantize(const GroupedTensor *input, const GroupedTensor *activations,
                    const Tensor *noop, GroupedTensor *output, Tensor *dbias, Tensor *workspace,
                    cudaStream_t stream) {
  using namespace group_quantize_kernel;
  checkCuDriverContext(stream);
  CheckNoopTensor(*noop, "cast_noop");

  const bool use_rowwise_scaling = output->has_data();
  const bool use_colwise_scaling = output->has_columnwise_data();
  NVTE_CHECK(use_rowwise_scaling || use_colwise_scaling,
             "Either rowwise or columnwise output data need to be allocated.");

  ScalingType scaling_type = ScalingType::BIDIMENSIONAL;
  if (!use_colwise_scaling) {
    scaling_type = ScalingType::ROWWISE;
  } else if (!use_rowwise_scaling) {
    scaling_type = ScalingType::COLWISE;
  }

  ShapeRepresentation shape_rep = ShapeRepresentation::SAME_BOTH_DIMS;
  if (output->all_same_shape()) {
    shape_rep = ShapeRepresentation::SAME_BOTH_DIMS;
  } else if (output->all_same_first_dim()) {
    shape_rep = ShapeRepresentation::VARYING_LAST_DIM;
  } else if (output->all_same_last_dim()) {
    shape_rep = ShapeRepresentation::VARYING_FIRST_DIM;
  } else if (output->varying_both_dims()) {
    shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS;
  }
  // Treat a grouped tensor with const last dims as a single tensor
  const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM);

  NVTE_CHECK(input->num_tensors == output->num_tensors,
             "Number of input and output tensors must be same.");
  NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data.");
  NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type.");

  if (IS_DACT) {
    NVTE_CHECK(activations->has_data(), "Activations tensor must have data.");
    NVTE_CHECK(input->num_tensors == activations->num_tensors,
               "Number of grad and activations tensors must be same.");
    NVTE_CHECK(input->dtype() == activations->dtype(),
               "Grad and activations tensors must have the same type.");
  }

  const size_t first_logical_dim = input->logical_shape.data[0];
  const size_t last_logical_dim = input->logical_shape.data[1];
  const size_t elts_total = first_logical_dim * last_logical_dim;

  const size_t num_tensors = input->num_tensors;

  size_t blocks = 0;
  if (is_single_tensor) {
    const size_t blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y);
    const size_t blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X);
    blocks = blocks_Y * blocks_X;
  } else {
    NVTE_CHECK(num_tensors < MAX_SUPPORTED_TENSOR_DESCRIPTORS,
               "Number of tensors in a group is larger than "
               "the MAX number of supported descriptors (64).");
    // Only full tiles supported
    NVTE_CHECK(last_logical_dim % CHUNK_DIM_X == 0,
               "Last dimension of a grouped tensor should be divisible by 128.");
    blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X);
  }

  const dim3 grid(blocks);
  const size_t block_size = THREADS_PER_CHUNK;

  const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales;

  // Logical shape of a tensor with varying all dims is [1, M*K]
  if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) {
    NVTE_CHECK(first_logical_dim % 128 == 0,
               "First dimension of a grouped tensor should be divisible by 128.");
  }

  const int64_t *const offsets_ptr = reinterpret_cast<const int64_t *>(input->tensor_offsets.dptr);
  const int64_t *const first_dims_ptr = reinterpret_cast<const int64_t *>(input->first_dims.dptr);
  const int64_t *const last_dims_ptr = reinterpret_cast<const int64_t *>(input->last_dims.dptr);

  float *const workspace_ptr = IS_DBIAS ? reinterpret_cast<float *>(workspace->data.dptr) : nullptr;
  float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
  const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);

  e8m0_t *const scales_rowwise_ptr = reinterpret_cast<e8m0_t *>(output->scale_inv.dptr);
  e8m0_t *const scales_colwise_ptr = reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr);
  if (use_rowwise_scaling) {
    NVTE_CHECK(scales_rowwise_ptr != nullptr, "Scaling tensor must be allocated");
  }
  if (use_colwise_scaling) {
    NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated");
  }

  const size_t dbias_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y);
  const size_t dbias_cols = last_logical_dim;
  if constexpr (IS_DBIAS) {
    NVTE_CHECK(is_single_tensor,
               "DBias is only supported for tensors with the const last dimension.");
    NVTE_CHECK(dbias->data.dtype == input->dtype(),
               "DBias must have the same type as input_tensor.");
    NVTE_CHECK(dbias->data.shape == std::vector<size_t>{last_logical_dim}, "Wrong shape of DBias.");
    NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor.");

    if (workspace->data.dptr == nullptr) {
      workspace->data.shape = {dbias_rows, dbias_cols};
      workspace->data.dtype = DType::kFloat32;
      return;
    }
  }
  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      input->dtype(), IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
          output->dtype(), OType,
          TRANSFORMER_ENGINE_SWITCH_CONDITION(
              with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,

              alignas(64) CUtensorMap tensor_map_input{};
              alignas(64) CUtensorMap tensor_map_act_input{};
              alignas(64) CUtensorMap tensor_map_output_rowwise{};
              alignas(64) CUtensorMap tensor_map_output_colwise{};

              constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
              constexpr size_t output_type_bit_size = TypeInfo<OType>::size;

              create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim,
                                   last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
                                   input_type_bit_size);

              if constexpr (IS_DACT) {
                create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim,
                                     last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
                                     input_type_bit_size);
              }

              if (use_rowwise_scaling) {
                create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim,
                                     last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
                                     output_type_bit_size);
              }

              if (use_colwise_scaling) {
                create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data,
                                     first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X,
                                     last_logical_dim, 0, output_type_bit_size);
              }

              constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
              constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
              constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
              constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
              constexpr size_t buff_size_aligned_in =
                  DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
              constexpr size_t buff_size_aligned_out =
                  DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);

              constexpr size_t elt_input_mem = buff_size_aligned_in;
              constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
              constexpr size_t in_mem = elt_input_mem + act_input_mem;

              const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0);
              const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0);
              const size_t out_mem = out_rowwise_mem + out_colwise_mem;

              const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;

              auto kernel =
                  group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
                                              true, true, WITH_GEMM_SWIZZLED_SCALES>;
              switch (scaling_type) {
                case ScalingType::ROWWISE: {
                  kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP,
                                                       IType, OType, true, false,
                                                       WITH_GEMM_SWIZZLED_SCALES>;
                  break;
                }
                case ScalingType::COLWISE: {
                  kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP,
                                                       IType, OType, false, true,
                                                       WITH_GEMM_SWIZZLED_SCALES>;
                  break;
                }
                case ScalingType::BIDIMENSIONAL: {
                  kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP,
                                                       IType, OType, true, true,
                                                       WITH_GEMM_SWIZZLED_SCALES>;
                  break;
                }
              }

              // Update tensor descriptors before launching the kernel
              if (!is_single_tensor) {
                const IType *const input_dptr = reinterpret_cast<const IType *>(input->data.dptr);
                const IType *const act_input_dptr =
                    IS_DACT ? reinterpret_cast<const IType *>(activations->data.dptr) : nullptr;
                OType *const output_rowwise_dptr =
                    use_rowwise_scaling ? reinterpret_cast<OType *>(output->data.dptr) : nullptr;
                OType *const output_colwise_dptr =
                    use_colwise_scaling ? reinterpret_cast<OType *>(output->columnwise_data.dptr)
                                        : nullptr;

                update_tma_descriptors<IType, OType><<<num_tensors, 32, 0, stream>>>(
                    tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
                    tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr,
                    output_colwise_dptr, shape_rep, num_tensors, first_logical_dim,
                    last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr,
                    use_rowwise_scaling, use_colwise_scaling, IS_DACT);
              }

              NVTE_CHECK_CUDA(cudaFuncSetAttribute(
                  kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));

              kernel<<<grid, block_size, dshmem_size, stream>>>(
                  tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
                  tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim,
                  last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr,
                  scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr);

              if constexpr (IS_DBIAS) {
                common::reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
              }
              NVTE_CHECK_CUDA(cudaGetLastError()););  // NOLINT(*)
      );                                              // NOLINT(*)
  );                                                  // NOLINT(*)
}

}  // namespace mxfp8
}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_
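Reviewer note: the kernel above maps each CUDA block to a member of the grouped tensor by searching the cumulative element offsets (offsets_ptr). A minimal host-side sketch of that lookup, with made-up offsets purely for illustration, may make the get_current_tensor_id binary search easier to follow.

#include <cstdint>
#include <cstdio>

// Host-side illustration (not part of the diff): given cumulative element offsets of the
// tensors in a group, find which tensor a flat element offset belongs to. This mirrors the
// upper-bound style binary search in get_current_tensor_id().
static size_t tensor_id_from_offset(const int64_t *offsets, size_t num_tensors, size_t offset) {
  size_t low = 1, hi = num_tensors;  // search range [low, hi]
  while (low < hi) {
    const size_t mid = low + (hi - low) / 2;
    if (static_cast<size_t>(offsets[mid]) <= offset) {
      low = mid + 1;
    } else {
      hi = mid;
    }
  }
  return low - 1;  // last tensor whose start offset is <= the queried offset
}

int main() {
  // Three tensors of 256, 512 and 128 elements; offsets are the running prefix sums.
  const int64_t offsets[] = {0, 256, 768, 896};
  std::printf("%zu %zu %zu\n",
              tensor_id_from_offset(offsets, 3, 0),     // -> 0
              tensor_id_from_offset(offsets, 3, 300),   // -> 1
              tensor_id_from_offset(offsets, 3, 800));  // -> 2
  return 0;
}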
transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh
...
...
@@ -25,6 +25,7 @@
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "core_nvfp4.cuh"
#include "specialized/quantize_transpose_nvfp4_tuned_1D.cuh"
namespace transformer_engine {
namespace dispatch {
...
...
@@ -1163,6 +1164,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
#if FP4_TYPE_SUPPORTED
  using namespace quantize_transpose_kernel;
  using namespace ptx;
  bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
  // If transposed output is allocated, return the transposed data. Otherwise, it's not necessary to
...
...
@@ -1170,6 +1172,11 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
  // TODO(Frank): Is there a better way to do this?
  bool return_transpose = output->has_columnwise_data();

  if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
    quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
    return;
  }

  constexpr bool COMPUTE_ACTIVATIONS = false;
  using ParamOP = Empty;
  constexpr float (*OP)(float, const ParamOP &) = nullptr;
...
...
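Reviewer note: the hunk above adds an early exit that routes BF16 inputs not using 2D quantization to the new tuned 1D NVFP4 kernel. As a reading aid, the routing condition could be factored into a standalone predicate; the helper below is hypothetical and not part of the commit.

#include <transformer_engine/transformer_engine.h>

// Hypothetical helper mirroring the condition added above: the tuned 1D NVFP4 path
// is taken only for BF16 inputs when 2D quantization has not been requested.
static bool use_tuned_1d_nvfp4_path(bool use_2d_quantization,
                                    transformer_engine::DType input_dtype) {
  return !use_2d_quantization && input_dtype == transformer_engine::DType::kBFloat16;
}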
transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh
0 → 100644
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file quantize_transpose_nvfp4_tuned_1D.cuh
* \brief Tuned kernel to cast to NVFP4 and transpose.
*/
#ifndef TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_
#define TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../../common.h"
#include "../../../util/math.h"
#include "../../../util/ptx.cuh"
#include "../../../utils.cuh"
#include "../core_nvfp4.cuh"
namespace transformer_engine {
namespace dispatch {
namespace nvfp4 {
namespace quantize_transpose_tuned_kernel {

using namespace quantization_and_transposition_SF;
using namespace core;
using namespace ptx;

#if FP4_TYPE_SUPPORTED

struct TunableConfig {
  static constexpr int CHUNK_DIM_Y = 128;
  static constexpr int CHUNK_DIM_X = 128;
  static constexpr int PREFETCH_STAGES = 1;
  static constexpr bool PERSISTENT = false;
};

constexpr int SCALE_DIM = 16;  // NVFP4 block (x16 elts)
constexpr int THREADS_NUM = 128;
constexpr int ELTS_PER_THREAD = 16;
constexpr int TILE_DIM_Y = 64;
constexpr int TILE_DIM_X = 64;

static_assert(ELTS_PER_THREAD == SCALE_DIM && "Hardcoded and fixed parameter\0");
static_assert((THREADS_NUM * ELTS_PER_THREAD <= TILE_DIM_Y * TILE_DIM_X) &&
              "Unbalanced threads workload\0");
static_assert((TunableConfig::CHUNK_DIM_Y % TILE_DIM_Y == 0) &&
              "Chunk size Y must be evenly divisible by the tile size Y\0");
static_assert((TunableConfig::CHUNK_DIM_X % TILE_DIM_X == 0) &&
              "Chunk size X must be evenly divisible by the tile size X\0");
static_assert((TILE_DIM_Y % SCALE_DIM == 0) &&
              "Tile size Y must be evenly divisible by the scale dim\0");
static_assert((TILE_DIM_X % SCALE_DIM == 0) &&
              "Tile size X must be evenly divisible by the scale dim\0");

constexpr int TILES_Y = TunableConfig::CHUNK_DIM_Y / TILE_DIM_Y;
constexpr int TILES_X = TunableConfig::CHUNK_DIM_X / TILE_DIM_X;

constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD;

constexpr int SCALES_PER_CHUNK_Y = TunableConfig::CHUNK_DIM_Y / SCALE_DIM;
constexpr int SCALES_PER_CHUNK_X = TunableConfig::CHUNK_DIM_X / SCALE_DIM;
constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM;
constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM;

constexpr int STAGES_Y = TILES_Y;
constexpr int STAGES_X = TILES_X;
constexpr int STAGES = STAGES_Y * STAGES_X;

constexpr int BUFFS_NUM = TunableConfig::PREFETCH_STAGES + 1;
constexpr int BUFFS_NUM_IN = BUFFS_NUM;
constexpr int BUFFS_NUM_OUT = BUFFS_NUM;
constexpr int BUFFS_NUM_OUT_TR = 2;

constexpr int BUFF_DIM_Y = TILE_DIM_Y;
constexpr int BUFF_DIM_X = TILE_DIM_X;
constexpr int BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X;
constexpr int BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM;

// Input buffer (BF16)
constexpr int BUFF_IN_DIM_Y = BUFF_DIM_Y;
constexpr int BUFF_IN_DIM_X = BUFF_DIM_X;
constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X;
constexpr int BUFF_IN_ELTS_NUM = BUFF_IN_DIM_Y * BUFF_IN_DIM_X;

// Output buffer (NVFP4)
constexpr int BUFF_OUT_DIM_Y = BUFF_DIM_Y;
constexpr int BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8;
constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X;

// Output transpose buffer (NVFP4)
constexpr int BUFF_OUT_TR_DIM_Y = BUFF_DIM_X;
constexpr int BUFF_OUT_TR_DIM_X = (BUFF_DIM_Y * 4) / 8;
constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X;

// Manual swizzling parameters to reduce SHMEM bank conflicts
constexpr int PACK_SIZE = 8;
constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE;

constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD;
constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE;

constexpr int THREADS_X_TR = TILE_DIM_X / 2;
constexpr int THREADS_Y_TR = THREADS_NUM / THREADS_X_TR;

constexpr int ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE;
constexpr int ITERATIONS_TR = SCALES_PER_TILE_Y / THREADS_Y_TR;
static_assert(ITERATIONS_TR >= 1 && "Number of transpose iterations should be >=1\0");
static_assert((SCALES_PER_TILE_Y % THREADS_Y_TR == 0) &&
              "Partial transpose iterations are not supported\0");

constexpr int BUFF_OUT_IT_OFFSET = BUFF_OUT_TR_DIM_X / ITERATIONS_TR / STAGES;

static_assert(BUFF_DIM_Y >= SCALE_DIM &&
              "Number of buffer rows must be greater or equal to the size of the columwise "
              "scaling block\0");
static_assert(TunableConfig::CHUNK_DIM_Y >= BUFF_DIM_Y);
static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE &&
              "Number of buffer rows must be greater or equal to the number of rowwise "
              "processing threads in Y dimension\0");

// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4;  // 256
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD;

using IType = bf16;
using IType2 = typename ptx::FPx2<IType>;
using IType3D = IType[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X];
using IType2x3D = IType2[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2];
using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X];
using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X];
using ScalesType2D = nvfp4_scale_t[TunableConfig::CHUNK_DIM_Y][SCALES_PER_CHUNK_X];
using ScalesTypeTr2D = nvfp4_scale_t[TunableConfig::CHUNK_DIM_X][SCALES_PER_CHUNK_Y];
using RNG_t = typename transformer_engine::curanddx::detail::philox4x32_native_state<10>;
template <bool USE_FAST_MATH>
struct SCALING_COEFFICIENT_TYPE {};

template <>
struct SCALING_COEFFICIENT_TYPE<false> {
  using type = float;
};

template <>
struct SCALING_COEFFICIENT_TYPE<true> {
  using type = bf16;
};

__device__ __forceinline__ float get_amax_of_pair(const IType2 pair) {
  return static_cast<float>(__hmax(__habs(pair.x), __habs(pair.y)));
}

// Compute "correct" per-block encoding scaling factor
template <typename SF_TYPE>
__device__ __forceinline__ SF_TYPE compute_nvfp4_scaling_coefficient(
    const nvfp4_scale_t S_dec_block, const float S_enc) {
  NVTE_DEVICE_ERROR("Unsupported scaling-factor type. Only FP32 and BF16 are supported.");
}

template <>
__device__ __forceinline__ float compute_nvfp4_scaling_coefficient<float>(
    const nvfp4_scale_t S_dec_block, const float S_enc) {
  const float S_dec = 1.0f / S_enc;
  const float scale_rcp =
      fminf(1.0f / (static_cast<float>(S_dec_block) * S_dec), detail::TypeExtrema<float>::max);
  return scale_rcp;
}

template <>
__device__ __forceinline__ bf16 compute_nvfp4_scaling_coefficient<bf16>(
    const nvfp4_scale_t S_dec_block, const float S_enc) {
  const float scale_rcp =
      fminf(S_enc / (static_cast<float>(S_dec_block)), detail::TypeExtrema<bf16>::max);
  return static_cast<bf16>(scale_rcp);
}

template <bool USE_STOCHASTIC_ROUNDING, bool USE_FAST_MATH>
__device__ __forceinline__ void colwise_scaling(
    const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_tr_ptr,
    nvfp4_scale_t *__restrict__ sSFcolwise_ptr, const float S_enc_colwise, const int stage_Y,
    const int stage_X, const int buff_in, const int buff_out_tr, RNG_t &rng, uint4 &random_uint4,
    int &rnd_idx) {
  using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE<USE_FAST_MATH>::type;

  const auto &sIn2x = *reinterpret_cast<const IType2x3D *>(sIn_ptr);
  auto &sOut_tr = *reinterpret_cast<OType2xt3D *>(sOut_tr_ptr);
  auto &sSFcolwise = *reinterpret_cast<ScalesTypeTr2D *>(sSFcolwise_ptr);

  const int warp = threadIdx.x / THREADS_PER_WARP;
  const int thread_lane = threadIdx.x % THREADS_PER_WARP;

  const int tid_Y_colwise = (thread_lane % 4 + warp) % 4;
  const int tid_X_colwise = thread_lane;

  const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM;
  const int thread_offset_X_colwise = tid_X_colwise * 2;

  const int in_thread_offset_Y = thread_offset_Y_colwise;
  const int in_thread_offset_X = thread_offset_X_colwise / 2;

  const int out_tr_thread_offset_Y = thread_offset_X_colwise;
  const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2;

  const int scale_tr_offset_Y = (stage_X * TILE_DIM_X) + 2 * tid_X_colwise;
  const int scale_tr_offset_X = (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise;

  __align__(8) IType rIn[2][SCALE_DIM];

  // Read (cache) a pair of input elements (S2R). Find NVFP4-block AMAX
  IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
  for (int i = 0; i < SCALE_DIM; ++i) {
    const IType2 elt_pair =
        ptx::ld_shared_b32(&sIn2x[buff_in][in_thread_offset_Y + i][in_thread_offset_X]);
    rIn[0][i] = elt_pair.x;
    rIn[1][i] = elt_pair.y;
    ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, elt_pair);
  }
  const float block_amax[2] = {static_cast<float>(__habs(thread_amax_2x.x)),
                               static_cast<float>(__habs(thread_amax_2x.y))};

#pragma unroll
  for (int w = 0; w < 2; ++w) {
    const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax[w], S_enc_colwise);

    // Store scaling factors to SMEM buffer (R2S)
    sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8;

    const scaling_coeff_type SFcoefficient =
        compute_nvfp4_scaling_coefficient<scaling_coeff_type>(S_dec_b_fp8, S_enc_colwise);

    // Scale elements
    __align__(8) uint32_t rOut[SCALE_DIM / 8];
#pragma unroll
    for (int e = 0; e < SCALE_DIM / 8; ++e) {
      const uint64_t elts03 = *reinterpret_cast<uint64_t *>(&rIn[w][8 * e]);
      const uint64_t elts47 = *reinterpret_cast<uint64_t *>(&rIn[w][8 * e + 4]);
      if constexpr (USE_STOCHASTIC_ROUNDING) {
        const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx);
        const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx);
        rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding<scaling_coeff_type>(
            elts03, elts47, SFcoefficient, rbits03, rbits47);
      } else {
        rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest<scaling_coeff_type>(elts03, elts47,
                                                                                   SFcoefficient);
      }
    }
    uint64_t &out_pack_16x = *reinterpret_cast<uint64_t *>(rOut);
    ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X],
                       out_pack_16x);
  }
}
template <bool USE_STOCHASTIC_ROUNDING, bool USE_FAST_MATH>
__device__ __forceinline__ void rowwise_scaling(
    const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_ptr,
    nvfp4_scale_t *__restrict__ sSFrowwise_ptr, const float S_enc_rowwise, const int stage_Y,
    const int stage_X, const int buff_in, const int buff_out, RNG_t &rng, uint4 &random_uint4,
    int &rnd_idx) {
  using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE<USE_FAST_MATH>::type;

  const auto &sIn = *reinterpret_cast<const IType3D *>(sIn_ptr);
  auto &sOut = *reinterpret_cast<OType2x3D *>(sOut_ptr);
  auto &sSFrowwise = *reinterpret_cast<ScalesType2D *>(sSFrowwise_ptr);

  const int thread_lane = threadIdx.x % THREADS_PER_WARP;
  const int bank_group = thread_lane / THREADS_PER_BANK;

  const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
  const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;

  const int thread_offset_Y_rowwise = tid_Y_rowwise;
  const int thread_offset_X_rowwise = tid_X_rowwise * ELTS_PER_THREAD;

  const int SF_thread_offset_rowwise_Y = tid_Y_rowwise;
  const int SF_thread_offset_rowwise_X = tid_X_rowwise / THREADS_PER_SCALE_ROWWISE;
  const bool SF_storing_thread = (tid_X_rowwise % THREADS_PER_SCALE_ROWWISE == 0);

  const int stage_rowwise_scales_offset_Y = SF_thread_offset_rowwise_Y + stage_Y * TILE_DIM_Y;
  const int stage_rowwise_scales_offset_X = SF_thread_offset_rowwise_X + stage_X * SCALES_PER_TILE_X;

#pragma unroll
  for (int it = 0; it < ITERATIONS_NORMAL; ++it) {
    const int it_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;

    __align__(16) IType2 rIn[WAVES][PACK_SIZE / 2];

    // Read (cache) input elements (S2R). Find NVFP4-block AMAX
    IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
    for (int w = 0; w < WAVES; ++w) {
      const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD;
      const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;

      // Load elements
      __uint128_t &elts_8x = *reinterpret_cast<__uint128_t *>(&rIn[w]);
      elts_8x = ptx::ld_shared_b128(&sIn[buff_in][it_offset_Y_rowwise][swizzled_thread_idx]);

#pragma unroll
      for (int e = 0; e < PACK_SIZE / 2; ++e) {
        ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]);
      }
    }

    const float block_amax = get_amax_of_pair(thread_amax_2x);
    const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise);
    const scaling_coeff_type SFcoefficient =
        compute_nvfp4_scaling_coefficient<scaling_coeff_type>(S_dec_b_fp8, S_enc_rowwise);

    // Store scaling factors to SMEM buffer (R2S)
    if (SF_storing_thread) {
      const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE;
      const int scales_offset_X = stage_rowwise_scales_offset_X;
      sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8;
    }

    // Scale elements
#pragma unroll
    for (int w = 0; w < WAVES; ++w) {
      const uint64_t elts03 = *reinterpret_cast<uint64_t *>(&rIn[w][0]);
      const uint64_t elts47 = *reinterpret_cast<uint64_t *>(&rIn[w][2]);
      uint32_t out_x8;
      if constexpr (USE_STOCHASTIC_ROUNDING) {
        const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx);
        const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx);
        out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding<scaling_coeff_type>(
            elts03, elts47, SFcoefficient, rbits03, rbits47);
      } else {
        out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest<scaling_coeff_type>(elts03, elts47,
                                                                                  SFcoefficient);
      }
      const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD;
      const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2;
      ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8);
    }
  }
}
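The swizzled index used in `rowwise_scaling` rotates which PACK_SIZE-wide group each bank group of a
warp touches first, so the 16-byte shared-memory loads of different bank groups do not start on the
same banks in the same wave. A small host-side sketch of the rotation, assuming ELTS_PER_THREAD = 32,
PACK_SIZE = 8 and THREADS_PER_BANK = 8 (the real constants come from the kernel configuration; these
numbers are only for illustration):

#include <cstdio>

// Prints the swizzled starting offset used by each (bank_group, wave) pair,
// under the assumed constants noted above.
int main() {
  constexpr int ELTS_PER_THREAD = 32;
  constexpr int PACK_SIZE = 8;
  constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE;  // 4

  for (int bank_group = 0; bank_group < 4; ++bank_group) {
    for (int w = 0; w < WAVES; ++w) {
      const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD;
      std::printf("bank_group=%d wave=%d -> offset %2d\n", bank_group, w, swizzled_group_idx);
    }
  }
  return 0;
}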
template <bool USE_STOCHASTIC_ROUNDING, bool USE_FAST_MATH, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D_kernel(
    const __grid_constant__ CUtensorMap tensor_map_input,
    const __grid_constant__ CUtensorMap tensor_map_output,
    const __grid_constant__ CUtensorMap tensor_map_output_t, nvfp4_scale_t *const scales_ptr,
    nvfp4_scale_t *const scales_t_ptr, const float *noop, const float *const amax_rowwise_ptr,
    const float *const amax_colwise_ptr, const size_t rows, const size_t cols,
    const size_t scale_stride, const size_t scale_stride_t, const size_t *rng_state) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  if (noop != nullptr && noop[0] == 1.0f) {
    return;
  }

  const size_t rng_sequence =
      threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
  const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
  const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
  RNG_t rng;
  rng.init(rng_seed, rng_sequence, rng_offset);
  uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0};
  // Index of the random number. It increments each time when used and resets to 0 if reaches 4x
  int rnd_idx = 0;

  const bool leading_thread = (threadIdx.x == 0);

  constexpr int buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
  constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems;
  constexpr int buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr int buff_size_aligned_out =
      DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT);
  constexpr int buff_size_aligned_out_t =
      DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT);

  constexpr int in_mem = buff_size_aligned_in;
  constexpr int out_mem_rowwise_data = buff_size_aligned_out;
  constexpr int out_mem_colwise_data = RETURN_TRANSPOSE ? buff_size_aligned_out_t : 0;
  constexpr int out_mem_rowwise_scales = DIVUP_TO_MULTIPLE(
      TunableConfig::CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT);

  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
  extern __shared__ unsigned char dynamic_shmem[];
  unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem);

  IType *sIn_ptr = reinterpret_cast<IType *>(dshmem);
  fp4e2m1x2 *sOut_ptr = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
  fp4e2m1x2 *sOut_tr_ptr = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem + out_mem_rowwise_data);

  auto &sIn = *reinterpret_cast<IType3D *>(sIn_ptr);
  auto &sOut = *reinterpret_cast<OType2x3D *>(sOut_ptr);
  auto &sOut_tr = *reinterpret_cast<OType2xt3D *>(sOut_tr_ptr);

  nvfp4_scale_t *sSFrowwise_ptr = reinterpret_cast<nvfp4_scale_t *>(
      dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
  nvfp4_scale_t *sSFcolwise_ptr = reinterpret_cast<nvfp4_scale_t *>(
      dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);

  auto &sSFrowwise = *reinterpret_cast<ScalesType2D *>(sSFrowwise_ptr);
  auto &sSFcolwise = *reinterpret_cast<ScalesTypeTr2D *>(sSFcolwise_ptr);

  constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;

  // Compute a global encoding/decoding scaling factors for all S_dec_b
  const float S_enc_rowwise =
      (amax_rowwise_ptr == nullptr)
          ? 1.0f
          : core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
  const float S_enc_colwise =
      (amax_colwise_ptr == nullptr)
          ? S_enc_rowwise
          : core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);

  __shared__ uint64_t workID_mbar;
  __shared__ __uint128_t workID_response;
  constexpr uint32_t workID_response_size = sizeof(workID_response);
  static_assert(workID_response_size == 16);

  __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM];

  // Coordinates of the first chunk (CTA) to process
  int32_t ctaid_X = blockIdx.x;
  int32_t ctaid_Y = blockIdx.y;

  // Initialize shared memory barriers with the number of threads participating in them
  if (leading_thread) {
#pragma unroll
    for (int buff = 0; buff < BUFFS_NUM; ++buff) {
      ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1);
    }
    ptx::mbarrier_init(&workID_mbar, 1);
    ptx::fence_proxy_async_shared_cta();
  }
  __syncthreads();

  bool job_finished = false;

  int buff_in = 0;
  int buff_out = 0;
  int buff_out_tr = 0;

  int IN_buff_readable_parity[BUFFS_NUM] = {0, 0};
  int ctaid_parity = 0;

  // Prefetch input data only when processing the first chunk,
  // which enables the one-iteration overlap throughout the entire kernel life
#pragma unroll
  for (int stage = 0; stage < TunableConfig::PREFETCH_STAGES; ++stage) {
    const int buff_in = stage;
    const int stage_Y = stage / STAGES_X;
    const int stage_X = stage % STAGES_X;
    const int stage_offset_Y = stage_Y * TILE_DIM_Y;
    const int stage_offset_X = stage_X * TILE_DIM_X;
    const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y;
    const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X;
    const int global_offset_Y = block_offset_Y + stage_offset_Y;
    const int global_offset_X = block_offset_X + stage_offset_X;
    uint64_t *barrier = &IN_buff_readable_mbar[buff_in];

    if (leading_thread) {
      uint64_t *dst = reinterpret_cast<uint64_t *>(&sIn[buff_in]);
      const uint64_t *src = reinterpret_cast<const uint64_t *>(&tensor_map_input);

      // Arrive on the barrier and tell how many bytes are expected to come in
      ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size);

      // Initiate bulk tensor copy
      ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y,
                                                    barrier);
    }
  }

  while (!job_finished) {
    const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y;
    const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X;
    const int block_offset_Y_tr = ctaid_X * TunableConfig::CHUNK_DIM_X;
    const int block_offset_X_tr = ctaid_Y * TunableConfig::CHUNK_DIM_Y;

    const int chunk_rows = rows - block_offset_Y;
    const int chunk_cols = cols - block_offset_X;

    const int scales_block_offset_Y_rowwise = ctaid_Y * TunableConfig::CHUNK_DIM_Y;
    const int scales_block_offset_X_rowwise = ctaid_X * SCALES_PER_CHUNK_X;
    const int scales_block_offset_Y_tr = ctaid_X * TunableConfig::CHUNK_DIM_X;
    const int scales_block_offset_X_tr = ctaid_Y * SCALES_PER_CHUNK_Y;

    if constexpr (TunableConfig::PERSISTENT) {
      if (leading_thread) {
        ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size);
        ptx::try_cancel_cta(&workID_mbar, &workID_response);
      }
    }

#pragma unroll
    for (int stage = 0; stage < STAGES; ++stage) {
      const int stage_Y = stage / STAGES_X;
      const int stage_X = stage % STAGES_X;
      const int stage_offset_Y = stage_Y * TILE_DIM_Y;
      const int stage_offset_X = stage_X * TILE_DIM_X;

      if (stage == STAGES - TunableConfig::PREFETCH_STAGES) {
        if constexpr (TunableConfig::PERSISTENT) {
          ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity);
          ptx::get_cancelled_cta_id_2D(&workID_response, ctaid_X, ctaid_Y);
          ctaid_parity ^= 1;
        } else {
          ctaid_X = -1;
          ctaid_Y = -1;
        }
        if (ctaid_X == -1 && ctaid_Y == -1) {
          job_finished = true;
        }
      }

      // Prefetch next stage Input data
      if (!job_finished || (stage < STAGES - TunableConfig::PREFETCH_STAGES)) {
        const int next_prefetch_buff = (buff_in + TunableConfig::PREFETCH_STAGES) % BUFFS_NUM;
        const int next_prefetch_stage = (stage + TunableConfig::PREFETCH_STAGES) % STAGES;
        const int next_prefetch_stage_Y = next_prefetch_stage / STAGES_X;
        const int next_prefetch_stage_X = next_prefetch_stage % STAGES_X;
        const int next_prefetch_stage_offset_Y = next_prefetch_stage_Y * TILE_DIM_Y;
        const int next_prefetch_stage_offset_X = next_prefetch_stage_X * TILE_DIM_X;

        // Offsets change, because coordinates of the next "to-be-prefetched" CTA do also change
        const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y;
        const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X;
        const int global_offset_Y = block_offset_Y + next_prefetch_stage_offset_Y;
        const int global_offset_X = block_offset_X + next_prefetch_stage_offset_X;

        uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff];

        if (leading_thread) {
          uint64_t *dst = reinterpret_cast<uint64_t *>(&sIn[next_prefetch_buff]);
          const uint64_t *src = reinterpret_cast<const uint64_t *>(&tensor_map_input);

          // Arrive on the barrier and tell how many bytes are expected to come in
          ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size);

          // Initiate bulk tensor copy
          ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y,
                                                        barrier);
        }
        ptx::fence_proxy_async_shared_cta();
      }

      // Wait for the data to have arrived
      ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in],
                                                       IN_buff_readable_parity[buff_in]);
      IN_buff_readable_parity[buff_in] ^= 1;

      // Wait for TMA transfer to have finished reading shared memory
      // I.e. the OUT buffer is ready to be written to
      ptx::cp_async_bulk_wait_group_read<TunableConfig::PREFETCH_STAGES>();

      // NVFP4 Quantization
      rowwise_scaling<USE_STOCHASTIC_ROUNDING, USE_FAST_MATH>(sIn_ptr, sOut_ptr, sSFrowwise_ptr,
                                                              S_enc_rowwise, stage_Y, stage_X,
                                                              buff_in, buff_out, rng, random_uint4,
                                                              rnd_idx);
      if constexpr (RETURN_TRANSPOSE) {
        colwise_scaling<USE_STOCHASTIC_ROUNDING, USE_FAST_MATH>(sIn_ptr, sOut_tr_ptr,
                                                                sSFcolwise_ptr, S_enc_colwise,
                                                                stage_Y, stage_X, buff_in,
                                                                buff_out_tr, rng, random_uint4,
                                                                rnd_idx);
      }

      // Wait for shared memory writes to be visible to TMA engine
      ptx::fence_proxy_async_shared_cta();
      __syncthreads();
      // After syncthreads, writes by all threads are visible to TMA engine

      // Initiate TMA transfer to copy shared memory to global memory
      if (leading_thread) {
        const int global_offset_Y = block_offset_Y + stage_offset_Y;
        const int global_offset_X = block_offset_X + stage_offset_X;
        const int global_offset_Y_tr = block_offset_Y_tr + stage_offset_X;
        const int global_offset_X_tr = block_offset_X_tr + stage_offset_Y;

        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X,
            global_offset_Y, reinterpret_cast<uint64_t *>(&sOut[buff_out]));

        if constexpr (RETURN_TRANSPOSE) {
          ptx::cp_async_bulk_tensor_2d_shared_to_global(
              reinterpret_cast<const uint64_t *>(&tensor_map_output_t), global_offset_X_tr,
              global_offset_Y_tr, reinterpret_cast<uint64_t *>(&sOut_tr[buff_out_tr]));
        }

        // Create a "bulk async-group" out of the previous bulk copy operation
        ptx::cp_async_bulk_commit_group();
      }

      buff_in = (buff_in + 1) % BUFFS_NUM_IN;
      buff_out = (buff_out + 1) % BUFFS_NUM_OUT;
      buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR;
    }  // end of stages

    // Vectorized store of scaling factors (S2G)
    {
      // Rowwise
      {
        using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_X>;
        // number of scales in X dimension of this chunk
        const int count = min(SCALES_PER_CHUNK_X, chunk_cols / SCALE_DIM);
        for (size_t row = threadIdx.x; row < TunableConfig::CHUNK_DIM_Y; row += THREADS_NUM) {
          const size_t row_global = scales_block_offset_Y_rowwise + row;
          if (row_global < rows) {
            ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(sSFrowwise[row]);
            const size_t scale_idx_global =
                row_global * scale_stride + scales_block_offset_X_rowwise;
            scales_vec.store_to_elts(&scales_ptr[scale_idx_global], 0, count);
          }
        }
      }
      // Colwise
      if constexpr (RETURN_TRANSPOSE) {
        using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_Y>;
        // number of scales in Y dimension of this chunk
        const int count = min(SCALES_PER_CHUNK_Y, chunk_rows / SCALE_DIM);
        for (size_t row_tr = threadIdx.x; row_tr < TunableConfig::CHUNK_DIM_X;
             row_tr += THREADS_NUM) {
          const size_t row_tr_global = scales_block_offset_Y_tr + row_tr;
          if (row_tr_global < cols) {
            ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(sSFcolwise[row_tr]);
            const size_t scale_idx_global =
                row_tr_global * scale_stride_t + scales_block_offset_X_tr;
            scales_vec.store_to_elts(&scales_t_ptr[scale_idx_global], 0, count);
          }
        }
      }
      if (!job_finished) {
        // Ensures all reads from SFs buffer have completed and it's ready to be reused
        __syncthreads();
      }
    }
  }

  if (leading_thread) {
#pragma unroll
    for (int buff = 0; buff < BUFFS_NUM; ++buff) {
      ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]);
    }
    ptx::mbarrier_invalid(&workID_mbar);
  }
#else
  NVTE_DEVICE_ERROR("sm_100 or higher is required.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
#endif  // FP4_TYPE_SUPPORTED
}  // namespace quantize_transpose_tuned_kernel
inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, Tensor *output,
                                        const QuantizationConfig *quant_config,
                                        cudaStream_t stream) {
#if FP4_TYPE_SUPPORTED
  using namespace quantize_transpose_tuned_kernel;
  using namespace ptx;

  const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
  const bool use_fast_math = quant_config ? quant_config->use_fast_math : false;

  // If transposed output is allocated, return the transposed data.
  // Otherwise, it's not necessary to return the transposed data.
  const bool return_transpose = output->has_columnwise_data();

  checkCuDriverContext(stream);
  CheckNoopTensor(*noop, "cast_noop");
  CheckInputTensor(input, "input");
  CheckOutputTensor(*output, "output", false);

  NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
  NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated.");
  NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
  NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");

  if (return_transpose) {
    NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype),
               "Transposed output must have FP4 type.");
    NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
               "Transposed scaling tensor must be allocated");
  }

  const size_t rows = input.flat_first_dim();
  const size_t cols = input.flat_last_dim();

  NVTE_CHECK(rows % 32 == 0, "Number of tensor rows must be a multiple of 32");  // 16B alignment for TMA
  NVTE_CHECK(cols % 32 == 0, "Number of tensor cols must be a multiple of 32");  // 16B alignment for TMA

  const int blocks_Y = DIVUP(rows, static_cast<size_t>(TunableConfig::CHUNK_DIM_Y));
  const int blocks_X = DIVUP(cols, static_cast<size_t>(TunableConfig::CHUNK_DIM_X));
  const dim3 grid(blocks_X, blocks_Y);
  const int block_size = THREADS_NUM;

  const size_t scale_stride = output->scale_inv.shape[1];
  const size_t scale_stride_transpose =
      return_transpose ? output->columnwise_scale_inv.shape[1] : 0;

  nvfp4_scale_t *const scales_ptr = reinterpret_cast<nvfp4_scale_t *>(output->scale_inv.dptr);
  nvfp4_scale_t *const scales_transpose_ptr =
      reinterpret_cast<nvfp4_scale_t *>(output->columnwise_scale_inv.dptr);

  const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
  const float *const amax_rowwise_ptr = reinterpret_cast<const float *>(output->amax.dptr);
  const float *const amax_colwise_ptr =
      reinterpret_cast<const float *>(output->columnwise_amax.dptr);

  const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr;
  const size_t *rng_state = nullptr;
  if (rng_state_tensor != nullptr) {
    Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
    NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
               "RNG state should contain 2 64-bit values.");
    NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
               "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
    rng_state = reinterpret_cast<const size_t *>(rng_state_te_tensor.data.dptr);
  }

  alignas(64) CUtensorMap tensor_map_input{};
  alignas(64) CUtensorMap tensor_map_output{};
  alignas(64) CUtensorMap tensor_map_output_transpose{};

  create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
                       sizeof(IType) * 8);
  create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
                       4);
  if (return_transpose) {
    create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows,
                         BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4);
  }

  constexpr int buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
  constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems;
  constexpr int buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr int buff_size_aligned_out =
      DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT);
  constexpr int buff_size_aligned_out_t =
      DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT);
  constexpr int buff_size_scales = DIVUP_TO_MULTIPLE(
      TunableConfig::CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT);
  constexpr int buff_size_scales_transpose = DIVUP_TO_MULTIPLE(
      TunableConfig::CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT);

  const int in_mem = buff_size_aligned_in;

  const int out_data_mem = buff_size_aligned_out;
  const int out_data_transpose_mem = return_transpose ? buff_size_aligned_out_t : 0;
  const int out_scales_mem = buff_size_scales;
  const int out_scales_transpose_mem = return_transpose ? buff_size_scales_transpose : 0;

  const int out_mem = out_data_mem + out_data_transpose_mem;

  const int dshmem_size =
      in_mem + out_mem + out_scales_transpose_mem + out_scales_mem + TMA_SHMEM_ALIGNMENT;

  TRANSFORMER_ENGINE_SWITCH_CONDITION(
      use_stochastic_rounding, USE_STOCHASTIC_ROUNDING,
      TRANSFORMER_ENGINE_SWITCH_CONDITION(
          use_fast_math, USE_FAST_MATH,
          TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, {
            auto kernel =
                quantize_transpose_nvfp4_tuned_1D_kernel<USE_STOCHASTIC_ROUNDING, USE_FAST_MATH,
                                                         RETURN_TRANSPOSE>;
            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);

            kernel<<<grid, block_size, dshmem_size, stream>>>(
                tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr,
                scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols,
                scale_stride, scale_stride_transpose, rng_state);
          });););
#else
  NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif  // FP4_TYPE_SUPPORTED
}
}  // namespace nvfp4
}  // namespace dispatch
}  // namespace transformer_engine
#endif  // TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_
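One host-side detail worth calling out: because the per-CTA dynamic shared-memory budget computed
above (input buffers, rowwise/columnwise output buffers, scale buffers, plus alignment slack) can
exceed the default 48 KB limit, the launcher raises the kernel's maximum dynamic shared memory with
cudaFuncSetAttribute before launching. A generic sketch of that standard CUDA pattern follows; the
kernel name and the size are placeholders, not the ones used in the file above.

#include <cuda_runtime.h>
#include <cstdio>

__global__ void my_kernel() {  // placeholder kernel, stands in for the tuned NVFP4 kernel
  extern __shared__ unsigned char dynamic_shmem[];
  (void)dynamic_shmem;
}

int main() {
  const int dshmem_size = 100 * 1024;  // hypothetical budget above the 48 KB default
  // Opt the kernel into a larger dynamic shared-memory carve-out, then launch with that size.
  cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
  my_kernel<<<1, 128, dshmem_size>>>();
  cudaError_t err = cudaDeviceSynchronize();
  std::printf("launch status: %s\n", cudaGetErrorString(err));
  return 0;
}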
transformer_engine/common/comm_gemm/comm_gemm.cpp
View file @
9df0c4a3
...
...
@@ -8,7 +8,6 @@
#include <cublasmp.h>
#include <cuda_runtime.h>
#include <nvshmem.h>
#include <map>
#include <memory>
...
...
@@ -236,7 +235,7 @@ void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n
                                       ctx->grid_row_major.get(), ctx->d_desc.get()));
  const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE;
-  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
      ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
      sizeof epilogue));
}
...
...
@@ -273,46 +272,46 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
  const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
  const cublasOperation_t trans_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
-  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
      ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a, sizeof trans_a));
-  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
      ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b, sizeof trans_b));
  cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo);
-  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
      ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr,
      sizeof algo_attr));
  const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32;
  if (is_fp8_dtype(a->dtype())) {
    NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype");
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode,
        sizeof scale_mode));
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER,
        &a->scale_inv.dptr, sizeof(void*)));
  }
  if (is_fp8_dtype(b->dtype())) {
    NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype");
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode,
        sizeof scale_mode));
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER,
        &b->scale_inv.dptr, sizeof(void*)));
  }
  if (is_fp8_dtype(d->dtype())) {
    NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype");
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode,
        sizeof scale_mode));
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER, &d->scale.dptr,
        sizeof(void*)));
    if (d->amax.dptr) {
-      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
          ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER,
          &d->amax.dptr, sizeof(void*)));
    }
...
...
@@ -321,7 +320,7 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
  // Might be set to ALLREDUCE before, need to OR with the new flags to set.
  cublasMpMatmulEpilogue_t epilogue{};
  size_t size_read{};
-  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet(
+  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorGetAttribute(
      ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
      sizeof epilogue, &size_read));
  NVTE_CHECK(size_read == sizeof epilogue);
...
...
@@ -339,42 +338,42 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
                                       pre_act_out ? pre_act_out->data.dptr != nullptr : false,
                                       grad});
      it != flags_to_epilogue.end()) {
    epilogue = static_cast<cublasMpMatmulEpilogue_t>(epilogue | it->second);
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
        sizeof epilogue));
  }
  if (bias && bias->data.dptr) {
    cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype);
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type,
        sizeof bias_type));
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER, &bias->data.dptr,
        sizeof bias->data.dptr));
  }
  if (pre_act_out && pre_act_out->data.dptr) {
    cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype);
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE,
        &aux_type, sizeof aux_type));
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER,
        &pre_act_out->data.dptr, sizeof pre_act_out->data.dptr));
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd,
        sizeof ldd));
    if (is_fp8_dtype(pre_act_out->dtype())) {
      NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype");
-      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
          ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE,
          &scale_mode, sizeof scale_mode));
-      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
          ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER,
          &pre_act_out->scale.dptr, sizeof(void*)));
      if (pre_act_out->amax.dptr) {
-        NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+        NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
            ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER,
            &pre_act_out->amax.dptr, sizeof(void*)));
      }
...
...
@@ -382,12 +381,12 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
  }
  if (comm_sm_count) {
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
        ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT,
        &comm_sm_count, sizeof comm_sm_count));
  }
-  NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream));
+  NVTE_CHECK_CUBLASMP(cublasMpSetStream(ctx->cublas_mp.get(), main_stream));
  size_t wrksp_size_device{};
  size_t wrksp_size_host{};
...
...
@@ -423,8 +422,14 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
  std::vector<uint8_t> workspace_host(wrksp_size_host);
  if (ctx->workspace_size < wrksp_size_device) {
-    nvshmem_free(ctx->workspace);
-    ctx->workspace = nvshmem_malloc(wrksp_size_device);
+    if (ctx->workspace) {
+      NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
+      NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
+    }
+    NVTE_CHECK_CUBLASMP(cublasMpMalloc(ctx->grid_col_major.get(), &ctx->workspace,
+                                       wrksp_size_device));
+    NVTE_CHECK_CUBLASMP(cublasMpBufferRegister(ctx->grid_row_major.get(), ctx->workspace,
+                                               wrksp_size_device));
    ctx->workspace_size = wrksp_size_device;
  }
...
...
@@ -473,7 +478,10 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank
void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) {
  NVTE_API_CALL(nvte_comm_gemm_ctx_destroy);
  nvshmemx_sync_all_on_stream(ctx->stream.get());
+  if (ctx->workspace) {
+    NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
+    NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
+  }
  delete ctx;
}
...
...
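The workspace hunks above replace the raw nvshmem_malloc/nvshmem_free pair with cuBLASMp's own
allocate/register calls, but keep the same grow-only caching policy: the buffer is reallocated only
when the requested device workspace exceeds what the context already holds. A stripped-down sketch
of that policy using plain host allocations (illustration only; the real code uses the cublasMp*
calls shown in the diff above):

#include <cstdlib>

// Grow-only workspace cache: reallocate only when the request outgrows the buffer.
struct WorkspaceCache {
  void* ptr = nullptr;
  size_t size = 0;

  void ensure(size_t required) {
    if (size >= required) return;  // current buffer is large enough, reuse it
    std::free(ptr);                // stands in for deregister + free in the real code
    ptr = std::malloc(required);   // stands in for cublasMpMalloc + cublasMpBufferRegister
    size = ptr ? required : 0;
  }

  ~WorkspaceCache() { std::free(ptr); }
};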
transformer_engine/common/common.h
View file @
9df0c4a3
...
...
@@ -321,6 +321,9 @@ struct GroupedTensor {
  SimpleTensor columnwise_amax;
  SimpleTensor scale;  // for FP8-DS only
+  NVTEScalingMode scaling_mode;
+  size_t num_tensors;
  // Shape information (OPTIONAL - empty if dimension is uniform across all tensors)
  // first_dims[i] = first dimension of tensor i (empty if all tensors have same first dim)
  // last_dims[i] = last dimension of tensor i (empty if all tensors have same last dim)
...
...
@@ -338,10 +341,14 @@ struct GroupedTensor {
  // Always 2D with positive dimensions
  NVTEShape logical_shape;
-  NVTEScalingMode scaling_mode;
-  size_t num_tensors;
  NVTEGroupedTensor nvte_tensor;
+  /*! \brief Whether scaling factors are in format expected by GEMM
+   *
+   * Only meaningful for MXFP8 and NVFP4.
+   */
+  bool with_gemm_swizzled_scales = false;
  GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors)
      : data(),
        columnwise_data(),
...
...
@@ -350,12 +357,12 @@ struct GroupedTensor {
        amax(),
        columnwise_amax(),
        scale(),
+       scaling_mode(scaling_mode),
+       num_tensors(num_tensors),
        first_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
        last_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
        tensor_offsets(nullptr, std::vector<size_t>{0}, DType::kInt64),
        logical_shape(nvte_make_shape(nullptr, 1)),
-       scaling_mode(scaling_mode),
        nvte_tensor(0) {}
  explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; }
...
...
@@ -408,6 +415,7 @@ struct GroupedTensor {
    num_tensors = 0;
    scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
    nvte_tensor = 0;
+   with_gemm_swizzled_scales = false;
  }
};
...
...
transformer_engine/common/fused_attn/fused_attn.cpp
View file @
9df0c4a3
...
...
@@ -206,7 +206,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
    float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-   int64_t window_size_right, bool return_max_logit, bool cuda_graph) {
+   int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) {
  using namespace transformer_engine;
  NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
  const int device_id = cuda::current_device();
...
...
@@ -406,9 +406,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
       (window_size_right == -1 || window_size_right == 0)) ||
      // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd}
      (cudnn_runtime_version >= 90200 &&
-      ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
-       ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
-        (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+      (((window_size_left == -1 && window_size_right == -1 &&
+         attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) ||
+        ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 &&
+         (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
+          attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+          (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
+           max_seqlen_q == max_seqlen_kv))
       && max_seqlen_q <= max_seqlen_kv && dropout == 0.0 &&
...
...
@@ -418,12 +420,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
      // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd}
      (cudnn_runtime_version >= 90600 &&
-      ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
-       ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
+      ((window_size_left >= 0 || window_size_left == -1) &&
+       (window_size_right >= 0 || window_size_right == -1) &&
        ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
          // TODO(cyang): fix bug for BRCM + cross-attention on sm100
          (sm_arch_ < 100 ||
           (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && cudnn_runtime_version <= 90700) ||
                                cudnn_runtime_version > 90700)))) ||
         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
         (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
          (sm_arch_ < 100 ||
           (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv &&
...
...
@@ -440,7 +444,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
      // 9.13.1+: vanilla, off-by-one, learnable
-      (cudnn_runtime_version >= 91301 ||
-       (cudnn_runtime_version < 91301 && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) {
+      (cudnn_runtime_version >= 91301 ||
+       (cudnn_runtime_version < 91301 && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) &&
+      // determinism on Blackwell
+      // pre-9.18.1: fwd: deterministic; bwd: non-deterministic
+      // 9.18.1+: fwd: deterministic; bwd: non-deterministic/deterministic
+      (sm_arch_ < 100 ||
+       (sm_arch_ >= 100 &&
+        (!is_training ||
+         (is_training && !deterministic &&
+          (dropout == 0.0 || bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)) ||
+         (is_training && deterministic && cudnn_runtime_version >= 91801 && dropout == 0.0 &&
+          bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) {
    flag_arb = true;
  }
  if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
...
...
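Read as a standalone predicate, the Blackwell (sm_arch_ >= 100) clause added above admits the
arbitrary-seqlen backend roughly as follows. This is an illustrative paraphrase of the condition,
not code from this commit; the function and parameter names are made up:

// Illustrative paraphrase of the sm100 determinism clause (not part of the commit).
bool sm100_allows_arbitrary_seqlen(bool is_training, bool deterministic, float dropout,
                                   bool has_bias, int cudnn_runtime_version) {
  if (!is_training) return true;  // forward-only is always deterministic
  if (!deterministic) {
    // non-deterministic backward: allowed unless both dropout and bias are used
    return dropout == 0.0f || !has_bias;
  }
  // deterministic backward: needs cuDNN 9.18.1+, no dropout, and no bias
  return cudnn_runtime_version >= 91801 && dropout == 0.0f && !has_bias;
}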
@@ -506,16 +519,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// NVTE fused attention FWD with packed QKV
// DEPRECATED: This API is deprecated.
// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead.
-void nvte_fused_attn_fwd_qkvpacked(
-    const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
-    NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens,
-    const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
-    bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
-    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
-    NVTETensor workspace, cudaStream_t stream) {
+void nvte_fused_attn_fwd_qkvpacked(
+    const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
+    NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens,
+    const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
+    bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
+    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+    bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
  using namespace transformer_engine;
...
...
@@ -553,7 +564,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
      is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
      h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit,
-     cuda_graph);
+     cuda_graph, false);
  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
...
...
@@ -589,13 +600,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
    fused_attn_arbitrary_seqlen_fwd(
        b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training,
        return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
-       window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias,
-       input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens,
-       input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state,
-       wkspace, stream, handle);
+       window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view,
+       input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens,
+       input_cu_seqlens, input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr,
+       input_rng_state, wkspace, stream, handle);
#else
-    NVTE_ERROR("cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length.\n");
+    NVTE_ERROR("cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. "
+               "\n");
#endif
  } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
...
...
@@ -629,8 +641,8 @@ void nvte_fused_attn_bwd_qkvpacked(
    NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
    size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-   int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph,
-   NVTETensor workspace, cudaStream_t stream) {
+   int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+   bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
  using namespace transformer_engine;
...
...
@@ -669,7 +681,8 @@ void nvte_fused_attn_bwd_qkvpacked(
  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
      true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
-     max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph);
+     max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph,
+     deterministic);
  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
...
...
@@ -725,10 +738,11 @@ void nvte_fused_attn_bwd_qkvpacked(
    fused_attn_arbitrary_seqlen_bwd(
        b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type,
-       attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view,
-       &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view,
-       &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens,
-       input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
+       attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
+       deterministic, &Q_view, &K_view, &V_view, input_O, input_dO, input_Bias,
+       input_SoftmaxOffset, output_S, &dQ_view, &dK_view, &dV_view, output_dBias,
+       output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, input_cu_seqlens_padded,
+       input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
#else
    const char* err_msg =
        "cuDNN 8.9.0 is required for BF16/FP16 fused attention "
...
...
@@ -779,7 +793,8 @@ void nvte_fused_attn_fwd_kvpacked(
    size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph,
    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-   int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
+   int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace,
+   cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
  using namespace transformer_engine;
  const Tensor* input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
...
...
@@ -855,7 +870,7 @@ void nvte_fused_attn_fwd_kvpacked(
  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
      h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right,
-     return_max_logit, cuda_graph);
+     return_max_logit, cuda_graph, false);
  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
...
...
@@ -891,13 +906,14 @@ void nvte_fused_attn_fwd_kvpacked(
        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v,
        page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
        return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
-       window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias,
-       input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
-       input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k,
-       input_page_table_v, input_rng_state, wkspace, stream, handle);
+       window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view,
+       input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
+       input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
+       input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
-    NVTE_ERROR("cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length.\n");
+    NVTE_ERROR("cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. "
+               "\n");
#endif
  } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
...
...
@@ -933,8 +949,8 @@ void nvte_fused_attn_bwd_kvpacked(
    const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-   int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace,
-   cudaStream_t stream) {
+   int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
+   NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
  using namespace transformer_engine;
  const Tensor* input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
...
...
@@ -982,10 +998,10 @@ void nvte_fused_attn_bwd_kvpacked(
  const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
  const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
-  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
-      h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false,
-      cuda_graph);
+  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
+      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
+      h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false,
+      cuda_graph, deterministic);
  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
...
...
@@ -1040,11 +1056,11 @@ void nvte_fused_attn_bwd_kvpacked(
    fused_attn_arbitrary_seqlen_bwd(
        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout,
        qkv_layout,
-       bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic,
-       input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S,
-       output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q,
-       input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state,
-       wkspace, stream, handle);
+       bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right,
+       bottom_right_diagonal, deterministic, input_Q, &K_view, &V_view, input_O, input_dO,
+       input_Bias, input_SoftmaxOffset, output_S, output_dQ, &dK_view, &dV_view, output_dBias,
+       output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
+       input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
    const char* err_msg =
        "cuDNN 8.9.3 is required for BF16/FP16 fused attention "
...
...
@@ -1094,8 +1110,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
    bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
    NVTE_Softmax_Type softmax_type,
-   int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
-   cudaStream_t stream) {
+   int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+   NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_fwd);
  using namespace transformer_engine;
  const Tensor* input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
...
...
@@ -1166,7 +1182,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
      h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right,
-     return_max_logit, cuda_graph);
+     return_max_logit, cuda_graph, false);
  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
...
...
@@ -1183,13 +1199,14 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
        page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
        return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
-       window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias,
-       input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
-       input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k,
-       input_page_table_v, input_rng_state, wkspace, stream, handle);
+       window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V,
+       input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
+       input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
+       input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
-    NVTE_ERROR("cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length.\n");
+    NVTE_ERROR("cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. "
+               "\n");
#endif
  } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
...
...
@@ -1215,8 +1232,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
    size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-   int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph,
-   NVTETensor workspace, cudaStream_t stream) {
+   int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+   bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_bwd);
  using namespace transformer_engine;
  const Tensor* input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
...
...
@@ -1262,7 +1280,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
      h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false,
-     cuda_graph);
+     cuda_graph, deterministic);
  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
...
...
@@ -1289,8 +1307,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
    fused_attn_arbitrary_seqlen_bwd(
        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout,
        qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right,
-       deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias,
-       input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias,
+       bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO,
+       input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias,
        output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
        input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
...
...
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
View file @
9df0c4a3
...
...
@@ -55,10 +55,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
    int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
    bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-   int64_t window_size_left, int64_t window_size_right, void* devPtrQ, void* devPtrK,
-   void* devPtrV, void* devPtrBias, void* devPtrSoftmaxOffset, void* devPtrS1, void* devPtrS2,
-   void* devPtrO, void* devPtrDropoutSeed, void* devPtrDropoutOffset, void* devPtrCuSeqlensQ,
-   void* devPtrCuSeqlensKV, void* devPtrPageTableK, void* devPtrPageTableV,
+   int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+   void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrBias, void* devPtrSoftmaxOffset,
+   void* devPtrS1, void* devPtrS2, void* devPtrO, void* devPtrDropoutSeed,
+   void* devPtrDropoutOffset, void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
+   void* devPtrPageTableK, void* devPtrPageTableV,
    void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
    void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;
...
...
@@ -75,6 +75,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
  if (is_bottom_right && s_q == s_kv && !is_padding) {
    is_causal = true;
    is_bottom_right = false;
    bottom_right_diagonal = false;
  }
  bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
  bool is_dropout = (is_training && dropout_probability != 0.0f);
...
...
@@ -129,6 +130,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
      softmax_type, window_size_left, window_size_right, bottom_right_diagonal, true, tensorType,
      cudnn_frontend::DataType_t::NOT_SET,
...
...
@@ -248,15 +250,21 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
  fe::graph::SDPA_attributes sdpa_options;
  sdpa_options = fe::graph::SDPA_attributes()
                     .set_name("flash_attention")
-                     .set_is_inference(false)
+                     .set_generate_stats(generate_stats)
                     .set_causal_mask(is_causal)
                     .set_causal_mask_bottom_right(is_bottom_right)
                     .set_attn_scale(attn_scale);
  fe::DiagonalAlignment_t const &diagonal_alignment = bottom_right_diagonal
                                                          ? fe::DiagonalAlignment_t::BOTTOM_RIGHT
                                                          : fe::DiagonalAlignment_t::TOP_LEFT;
  sdpa_options.set_diagonal_alignment(diagonal_alignment);
  if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
    sdpa_options.set_diagonal_band_left_bound(window_size_left + 1);
  }
  if (cudnn_runtime_version >= 90600 && window_size_right != -1) {
    sdpa_options.set_diagonal_band_right_bound(window_size_right);
  }
  sdpa_options.set_alibi_mask(is_alibi);
...
...
@@ -542,13 +550,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
    int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
    float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
-    void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats,
-    void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV,
-    void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
-    void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
-    void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
-    void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose,
+    void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset,
+    void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
+    void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
+    void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ,
+    void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
+    size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;

  bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
...
...
@@ -563,6 +572,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
  if (is_bottom_right && s_q == s_kv && !is_padding) {
    is_causal = true;
    is_bottom_right = false;
    bottom_right_diagonal = false;
  }
  bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
  bool is_dropout = (dropout_probability != 0.0f);
...
...
@@ -621,6 +631,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
      softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic,
      tensorType, cudnn_frontend::DataType_t::NOT_SET,
...
...
@@ -781,9 +792,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
    sdpa_backward_options.set_max_total_seq_len_kv(s_kv);
  }
  fe::DiagonalAlignment_t const &diagonal_alignment = bottom_right_diagonal
                                                          ? fe::DiagonalAlignment_t::BOTTOM_RIGHT
                                                          : fe::DiagonalAlignment_t::TOP_LEFT;
  sdpa_backward_options.set_diagonal_alignment(diagonal_alignment);
  if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
    sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1);
  }
  if (cudnn_runtime_version >= 90600 && window_size_right != -1) {
    sdpa_backward_options.set_diagonal_band_right_bound(window_size_right);
  }
  if (cudnn_runtime_version >= 90000) {
    sdpa_backward_options.set_deterministic_algorithm(deterministic);
...
...
@@ -1044,8 +1063,8 @@ void fused_attn_arbitrary_seqlen_fwd(
    size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
    bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
-    const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
    const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
    const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
...
...
@@ -1180,11 +1199,11 @@ void fused_attn_arbitrary_seqlen_fwd(
        max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
        page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
        return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type,
-        window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias,
-        devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
-        devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
-        devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
-        workspace->data.dptr, &workspace_size, stream, handle);
+        window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV,
+        devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed,
+        devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK,
+        devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
+        workspace->data.dptr, &workspace_size, stream, handle);

  if (workspace_size > 0) {
    if (workspace->data.dptr == nullptr) {
...
...
@@ -1206,13 +1225,14 @@ void fused_attn_arbitrary_seqlen_bwd(
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
    size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic,
-    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
-    const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
-    Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV,
-    Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
-    const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
-    const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
-    cudaStream_t stream, cudnnHandle_t handle) {
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
+    const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
+    const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
+    Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
+    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+    const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
+    cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;

  const auto QKV_type = input_Q->data.dtype;
  void *devPtrQ = input_Q->data.dptr;
...
...
@@ -1273,8 +1293,8 @@ void fused_attn_arbitrary_seqlen_bwd(
        batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk,
        head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale,
        p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
-        window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
-        devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
+        window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV,
+        devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK,
+        devPtrdV, devPtrdO, devPtrdBias,
        devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
        devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
        workspace->data.dptr, &workspace_size, stream, handle);
...
...
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
View file @
9df0c4a3
...
...
@@ -25,8 +25,8 @@ void fused_attn_arbitrary_seqlen_fwd(
    size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
    bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
-    const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
    const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
    const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
...
...
@@ -37,13 +37,14 @@ void fused_attn_arbitrary_seqlen_bwd(
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
    size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic,
-    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
-    const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
-    Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV,
-    Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
-    const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
-    const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
-    cudaStream_t stream, cudnnHandle_t handle);
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
+    const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
+    const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
+    Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
+    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+    const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
+    cudaStream_t stream, cudnnHandle_t handle);
#endif  // CUDNN_VERSION >= 8900
}  // namespace transformer_engine
...
...
transformer_engine/common/fused_attn/fused_attn_fp8.cu
View file @
9df0c4a3
...
...
@@ -1707,6 +1707,7 @@ void fused_attn_fp8_fwd_impl_v1(
                             0, 0, true, true, qkv_tensor_type, o_tensor_type,
                             cudnn_frontend::DataType_t::NOT_SET,
...
...
@@ -1809,7 +1810,7 @@ void fused_attn_fp8_fwd_impl_v1(
  fe::graph::SDPA_fp8_attributes sdpa_options;
  sdpa_options = fe::graph::SDPA_fp8_attributes()
                     .set_name("sdpa_fp8")
-                     .set_is_inference(false)
+                     .set_generate_stats(true)
                     .set_causal_mask(is_causal)
                     .set_attn_scale(attn_scale);
...
...
@@ -2035,6 +2036,7 @@ void fused_attn_fp8_bwd_impl_v1(
                             NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, 0, 0, true, false,
                             qkv_tensor_type, o_tensor_type,
...
...
transformer_engine/common/fused_attn/utils.cu
View file @
9df0c4a3
...
...
@@ -535,11 +535,13 @@ size_t get_max_batch_size(size_t batch_size) {
  // batch size is expected to be 10s-100s
  // b = 1, ..., 32   -> max_b = 32
  // b = 33, ..., 512 -> max_b = next power of 2
-  // otherwise        -> max_b = b
+  // b = 513, ...     -> max_b = increment by 512
  if (log2_b <= 5) {
    max_b = 32;
  } else if (log2_b <= 9) {
    max_b = pow(2, log2_b);
  } else {
    max_b = (batch_size + 511) / 512 * 512;
  }
  return max_b;
}
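// Illustrative mapping added for clarity (not part of this commit's change); it simply applies
// the rounding rule described above, assuming log2_b = ceil(log2(batch_size)):
//   batch_size = 24  -> max_b = 32
//   batch_size = 100 -> max_b = 128   (next power of 2)
//   batch_size = 600 -> max_b = 1024  (rounded up to a multiple of 512)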
...
...
transformer_engine/common/fused_attn/utils.h
View file @
9df0c4a3
...
...
@@ -111,6 +111,7 @@ struct FADescriptor_v1 {
  NVTE_Softmax_Type softmax_type;
  std::int64_t window_size_left;
  std::int64_t window_size_right;
  bool bottom_right_diagonal;
  bool deterministic;
  cudnn_frontend::DataType_t qkv_tensor_type;
  cudnn_frontend::DataType_t o_tensor_type;
...
...
@@ -122,15 +123,16 @@ struct FADescriptor_v1 {
    return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
                    page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
                    attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
-                    window_size_left, window_size_right, deterministic, bias_type,
-                    qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type,
-                    generate_max_sum_exp) <
+                    window_size_left, window_size_right, bottom_right_diagonal, deterministic,
+                    bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type,
+                    generate_max_sum_exp) <
           std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                    rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
                    rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
                    rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
-                    rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
-                    rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
-                    rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
+                    rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal,
+                    rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type,
+                    rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
  }
};
...
...
transformer_engine/common/gemm/config.cpp
View file @
9df0c4a3
...
...
@@ -126,3 +126,106 @@ void nvte_destroy_matmul_config(NVTEMatmulConfig config) {
    delete reinterpret_cast<transformer_engine::MatmulConfig *>(config);
  }
}

NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config() {
  return new transformer_engine::GroupedMatmulConfig;
}

void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
                                              NVTEGroupedMatmulConfigAttribute attr, void *buf,
                                              size_t size_in_bytes, size_t *size_written) {
  // Write attribute size
  NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes,
             "Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
  NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
  const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr];
  *size_written = attr_size;

  // Return immediately if buffer is not provided
  if (buf == nullptr) {
    return;
  }

  // Check buffer size
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for grouped matmul config attribute "
             "(attribute ", static_cast<int>(attr), " needs ", attr_size,
             " bytes, but buffer has ", size_in_bytes, " bytes)");

  // Write to buffer
  NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)");
  const auto &config_ = *reinterpret_cast<const transformer_engine::GroupedMatmulConfig *>(config);
  switch (attr) {
    case kNVTEGroupedMatmulConfigAvgM: {
      int64_t val = config_.avg_m.value_or(0);
      std::memcpy(buf, &val, attr_size);
      break;
    }
    case kNVTEGroupedMatmulConfigAvgN: {
      int64_t val = config_.avg_n.value_or(0);
      std::memcpy(buf, &val, attr_size);
      break;
    }
    case kNVTEGroupedMatmulConfigAvgK: {
      int64_t val = config_.avg_k.value_or(0);
      std::memcpy(buf, &val, attr_size);
      break;
    }
    case kNVTEGroupedMatmulConfigSMCount:
      std::memcpy(buf, &config_.sm_count, attr_size);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr),
                 ")");
  }
}

void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
                                              NVTEGroupedMatmulConfigAttribute attr,
                                              const void *buf, size_t size_in_bytes) {
  // Check attribute and buffer
  NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes,
             "Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
  const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr];
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for grouped matmul config attribute "
             "(attribute ", static_cast<int>(attr), " needs ", attr_size,
             " bytes, but buffer has ", size_in_bytes, " bytes)");
  NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");

  // Read from buffer
  NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)");
  auto &config_ = *reinterpret_cast<transformer_engine::GroupedMatmulConfig *>(config);
  switch (attr) {
    case kNVTEGroupedMatmulConfigAvgM: {
      int64_t val;
      std::memcpy(&val, buf, attr_size);
      config_.avg_m = val;
      break;
    }
    case kNVTEGroupedMatmulConfigAvgN: {
      int64_t val;
      std::memcpy(&val, buf, attr_size);
      config_.avg_n = val;
      break;
    }
    case kNVTEGroupedMatmulConfigAvgK: {
      int64_t val;
      std::memcpy(&val, buf, attr_size);
      config_.avg_k = val;
      break;
    }
    case kNVTEGroupedMatmulConfigSMCount:
      std::memcpy(&config_.sm_count, buf, attr_size);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr),
                 ")");
  }
}

void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config) {
  if (config != nullptr) {
    delete reinterpret_cast<transformer_engine::GroupedMatmulConfig *>(config);
  }
}
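// Hedged usage sketch (illustrative only, not part of this commit): exercising the C API above.
// The header assumed to declare these symbols is <transformer_engine/gemm.h>; adjust if the
// declarations live elsewhere.
#include <cstddef>
#include <cstdint>
#include <transformer_engine/gemm.h>

void configure_grouped_gemm_hints() {
  NVTEGroupedMatmulConfig cfg = nvte_create_grouped_matmul_config();

  // Hint the average M dimension so the cuBLASLt heuristics need not derive it from shapes.
  int64_t avg_m = 4096;
  nvte_set_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigAvgM, &avg_m,
                                           sizeof(avg_m));

  // Read it back; size_written reports the attribute size even when buf is NULL.
  int64_t readback = 0;
  size_t written = 0;
  nvte_get_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigAvgM, &readback,
                                           sizeof(readback), &written);

  nvte_destroy_grouped_matmul_config(cfg);
}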
transformer_engine/common/gemm/config.h
View file @
9df0c4a3
...
...
@@ -9,6 +9,9 @@
#include <transformer_engine/transformer_engine.h>

#include <cstdint>
#include <optional>

namespace transformer_engine {

struct MatmulConfig {
...
...
@@ -31,6 +34,22 @@ struct MatmulConfig {
  };
};

struct GroupedMatmulConfig {
  // Average dimension hints for cuBLASLt algorithm selection heuristics.
  // nullopt means "not set" - compute automatically from tensor shapes.
  std::optional<int64_t> avg_m;
  std::optional<int64_t> avg_n;
  std::optional<int64_t> avg_k;

  // Number of streaming multiprocessors to use in GEMM kernel
  int sm_count = 0;

  // Note: API transfers the value type, not std::optional
  static constexpr size_t attr_sizes[] = {sizeof(decltype(avg_m)::value_type),
                                          sizeof(decltype(avg_n)::value_type),
                                          sizeof(decltype(avg_k)::value_type),
                                          sizeof(sm_count)};
};

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_GEMM_CONFIG_H_
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
9df0c4a3
...
...
@@ -311,13 +311,6 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
  return ret;
}

-/* cuBLAS version number at run-time */
-size_t cublas_version() {
-  // Cache version to avoid cuBLAS logging overhead
-  static size_t version = cublasLtGetVersion();
-  return version;
-}

}  // namespace

#endif  // __HIP_PLATFORM_AMD__
...
...
@@ -518,8 +511,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#endif  // CUBLAS_VERSION >= 120800
  } else if (mxfp8_gemm) {
#if CUBLAS_VERSION >= 120800
-    NVTE_CHECK(cublas_version() >= 120800,
-               "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
+    NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120800,
+               "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ",
+               transformer_engine::cuda::cublas_version());

    // Check that scales are in expected format
    NVTE_CHECK(inputA->with_gemm_swizzled_scales,
...
...
@@ -541,7 +535,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
    // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
    // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
-    if (cublas_version() <= 120803) {
+    if (transformer_engine::cuda::cublas_version() <= 120803) {
      const int64_t dummy_a_vec_stride = 1;
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                       CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE,
                                                       &dummy_a_vec_stride,
...
...
@@ -553,8 +547,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#endif  // CUBLAS_VERSION >= 120800
  } else if (use_fp4) {  // NVFP4 GEMM
#if CUBLAS_VERSION >= 120800
-    NVTE_CHECK(cublas_version() >= 120800,
-               "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
+    NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120800,
+               "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ",
+               transformer_engine::cuda::cublas_version());

    // Check that scales are in expected format
    NVTE_CHECK(inputA->with_gemm_swizzled_scales,
...
...
@@ -589,9 +584,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
             (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
              inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
#if CUBLAS_VERSION >= 120900
-    NVTE_CHECK(cublas_version() >= 120900,
+    NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120900,
               "FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ",
-               cublas_version());
+               transformer_engine::cuda::cublas_version());

    // Check that matrix formats are valid
    NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
...
...
@@ -624,7 +619,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
  }

#if CUBLAS_VERSION >= 120800
-  if (cublas_version() >= 120800) {
+  if (transformer_engine::cuda::cublas_version() >= 120800) {
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
                                                     &scaling_mode_a, sizeof(scaling_mode_a)));
...
...
@@ -641,7 +636,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                     CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax,
                                                     sizeof(D_amax)));
#if CUBLAS_VERSION >= 120800
-    if (cublas_version() >= 120800) {
+    if (transformer_engine::cuda::cublas_version() >= 120800) {
      // NOTE: In all current cases where FP8 output is supported, the input is
      // scaled identically to the output.
      NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
...
...
@@ -725,12 +720,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is "
             CUBLAS_VERSION);
#else
-  NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
-             "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
-             cuda::cudart_version());
-  NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
-             "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
-             cublas_version());
+  NVTE_CHECK(transformer_engine::cuda::cudart_version() >= 12020 &&
+                 transformer_engine::cuda::cudart_version() < 13000,
+             "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
+             transformer_engine::cuda::cudart_version());
+  NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120205 &&
+                 transformer_engine::cuda::cublas_version() < 130000,
+             "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
+             transformer_engine::cuda::cublas_version());

  if (m_split == 0) m_split = 1;
  if (n_split == 0) n_split = 1;
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
...
...
@@ -1201,9 +1198,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is "
             transformer_engine::cuda::cudart_version());
-  NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
+  NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120205 &&
+                 transformer_engine::cuda::cublas_version() < 130000,
             "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
-             cublas_version());
+             transformer_engine::cuda::cublas_version());
#endif
#else
#define NVTE_CUBLAS_ATOMIC_GEMM_COMPILE 1
...
...
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
0 → 100644
View file @
9df0c4a3
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>

#include <cstdint>

#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#include "./config.h"

namespace {

inline void CreateCublasHandle(cublasLtHandle_t *handle) {
  NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
}

}  // namespace

#if CUBLAS_VERSION >= 130200
namespace {

// Helper struct to pass per-tensor shape/offset info (pointer or uniform value)
struct TensorShapeInfo {
  const int64_t *first_dims;  // nullptr if uniform
  const int64_t *last_dims;   // nullptr if uniform
  const int64_t *offsets;     // nullptr if need to compute
  int64_t uniform_first;      // used if first_dims == nullptr
  int64_t uniform_last;       // used if last_dims == nullptr

  // Create from GroupedTensor
  static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) {
    const bool has_first = t->first_dims.has_data();
    const bool has_last = t->last_dims.has_data();
    // When per-tensor dims are not provided, we must be in the uniform-shape case.
    NVTE_CHECK(has_first || t->all_same_first_dim(),
               "GroupedTensor is missing first_dims for varying shapes");
    NVTE_CHECK(has_last || t->all_same_last_dim(),
               "GroupedTensor is missing last_dims for varying shapes");
    const int64_t *first_ptr =
        has_first ? static_cast<const int64_t *>(t->first_dims.dptr) : nullptr;
    const int64_t *last_ptr = has_last ? static_cast<const int64_t *>(t->last_dims.dptr) : nullptr;
    const int64_t uniform_first =
        has_first ? 0 : static_cast<int64_t>(t->get_common_first_dim());
    const int64_t uniform_last = has_last ? 0 : static_cast<int64_t>(t->get_common_last_dim());
    return {first_ptr, last_ptr,
            t->tensor_offsets.has_data() ? static_cast<const int64_t *>(t->tensor_offsets.dptr)
                                         : nullptr,
            uniform_first, uniform_last};
  }
  // Create for C tensor (uses D's dimensions, only has offsets)
  static TensorShapeInfo create_shape_info_for_C(const transformer_engine::GroupedTensor *C,
                                                 const transformer_engine::GroupedTensor *D) {
    const bool has_first = D->first_dims.has_data();
    const bool has_last = D->last_dims.has_data();
    NVTE_CHECK(has_first || D->all_same_first_dim(),
               "GroupedTensor D is missing first_dims for varying shapes");
    NVTE_CHECK(has_last || D->all_same_last_dim(),
               "GroupedTensor D is missing last_dims for varying shapes");
    const int64_t *first_ptr =
        has_first ? static_cast<const int64_t *>(D->first_dims.dptr) : nullptr;
    const int64_t *last_ptr = has_last ? static_cast<const int64_t *>(D->last_dims.dptr) : nullptr;
    const int64_t uniform_first =
        has_first ? 0 : static_cast<int64_t>(D->get_common_first_dim());
    const int64_t uniform_last = has_last ? 0 : static_cast<int64_t>(D->get_common_last_dim());
    return {first_ptr, last_ptr,
            C->tensor_offsets.has_data() ? static_cast<const int64_t *>(C->tensor_offsets.dptr)
                                         : nullptr,
            uniform_first, uniform_last};
  }
};
// Helper functions to compute average dimensions from logical_shape for heuristics
// These are hints for cuBLASLt algorithm selection, don't need to be exact
inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) {
  // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first)
  // In both cases, dividing by num_tensors gives the average
  return static_cast<int64_t>(t->logical_shape.data[0]) / static_cast<int64_t>(t->num_tensors);
}

inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) {
  if (t->all_same_last_dim()) {
    // logical_shape[1] is the common N
    return static_cast<int64_t>(t->logical_shape.data[1]);
  }
  // When varying, logical_shape[1] should be sum of last dims if provided; otherwise fallback to
  // avg via division.
  return static_cast<int64_t>(t->logical_shape.data[1]) / static_cast<int64_t>(t->num_tensors);
}
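// Worked example (illustrative, not part of this commit): a GroupedTensor packing 8 uniform
// 1024x512 matrices has logical_shape[0] = 8 * 1024 = 8192, so compute_avg_first_dim returns
// 8192 / 8 = 1024; with varying first dims summing to 6000 across 8 tensors it returns
// 6000 / 8 = 750.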
// Workspace layout for grouped GEMM
struct GroupedGemmSetupWorkspace {
  void **A_ptrs;
  void **B_ptrs;
  void **C_ptrs;
  void **D_ptrs;
  float **alpha_ptrs;
  float **beta_ptrs;
  // Storage dimensions for cuBLAS matrix layouts
  int *a_rows;
  int *a_cols;
  int *b_rows;
  int *b_cols;
  int *d_rows;  // M (first dim) - also used for C
  int *d_cols;  // N (last dim) - also used for C

  // Initialize from workspace buffer
  // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned)
  static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) {
    GroupedGemmSetupWorkspace ws;
    size_t offset = 0;
    const size_t ptr_size = num_tensors * sizeof(void *);
    const size_t int_size = num_tensors * sizeof(int);

    // Pointer arrays first (all 8-byte aligned)
    ws.A_ptrs = reinterpret_cast<void **>(setup_ws_ptr + offset);
    offset += ptr_size;
    ws.B_ptrs = reinterpret_cast<void **>(setup_ws_ptr + offset);
    offset += ptr_size;
    ws.C_ptrs = reinterpret_cast<void **>(setup_ws_ptr + offset);
    offset += ptr_size;
    ws.D_ptrs = reinterpret_cast<void **>(setup_ws_ptr + offset);
    offset += ptr_size;
    ws.alpha_ptrs = reinterpret_cast<float **>(setup_ws_ptr + offset);
    offset += ptr_size;
    ws.beta_ptrs = reinterpret_cast<float **>(setup_ws_ptr + offset);
    offset += ptr_size;

    // Int arrays for storage dimensions (4-byte aligned)
    ws.a_rows = reinterpret_cast<int *>(setup_ws_ptr + offset);
    offset += int_size;
    ws.a_cols = reinterpret_cast<int *>(setup_ws_ptr + offset);
    offset += int_size;
    ws.b_rows = reinterpret_cast<int *>(setup_ws_ptr + offset);
    offset += int_size;
    ws.b_cols = reinterpret_cast<int *>(setup_ws_ptr + offset);
    offset += int_size;
    ws.d_rows = reinterpret_cast<int *>(setup_ws_ptr + offset);
    offset += int_size;
    ws.d_cols = reinterpret_cast<int *>(setup_ws_ptr + offset);
    return ws;
  }

  // Calculate required size for setup workspace
  static size_t required_setup_size(size_t num_tensors, size_t alignment) {
    const size_t ptr_size = num_tensors * sizeof(void *);
    const size_t int_size = num_tensors * sizeof(int);
    // Layout: 6 ptr arrays, then 6 int arrays
    size_t size = 6 * ptr_size + 6 * int_size;
    size = ((size + alignment - 1) / alignment) * alignment;
    return size;
  }
};
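// Worked example (illustrative, not part of this commit): with num_tensors = 64 and the 256-byte
// kGroupedGemmAlignment defined below, required_setup_size = 6 * (64 * 8) + 6 * (64 * 4)
// = 3072 + 1536 = 4608 bytes, already a multiple of 256, so no padding is added.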
// -----------------------------------------------------------------------------
// Helper routines to keep nvte_grouped_gemm readable
// -----------------------------------------------------------------------------

inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA,
                                         const transformer_engine::GroupedTensor *inputB,
                                         const transformer_engine::GroupedTensor *inputC,
                                         const transformer_engine::GroupedTensor *outputD,
                                         const transformer_engine::Tensor *alpha_tensor,
                                         const transformer_engine::Tensor *beta_tensor) {
  const size_t num_tensors = inputA->num_tensors;
  NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: number of tensors must be at least 1");
  NVTE_CHECK(inputB->num_tensors == num_tensors,
             "Grouped GEMM: A and B must have the same number of tensors");
  // C can be NULL (will use D as C when beta=0)
  if (inputC != nullptr) {
    NVTE_CHECK(inputC->num_tensors == num_tensors,
               "Grouped GEMM: A and C must have the same number of tensors");
  }
  NVTE_CHECK(outputD->num_tensors == num_tensors,
             "Grouped GEMM: A and D must have the same number of tensors");

  // Validate alpha/beta have per-matrix values
  const size_t alpha_numel = alpha_tensor->data.numel();
  const size_t beta_numel = beta_tensor->data.numel();
  NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (",
             num_tensors, ") elements, got ", alpha_numel);
  NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors,
             ") elements, got ", beta_numel);

  auto is_fp8_or_16bit = [](transformer_engine::DType dtype) {
    return dtype == transformer_engine::DType::kFloat8E4M3 ||
           dtype == transformer_engine::DType::kFloat8E5M2 ||
           dtype == transformer_engine::DType::kBFloat16 ||
           dtype == transformer_engine::DType::kFloat16;
  };
  auto is_output_dtype = [](transformer_engine::DType dtype) {
    return dtype == transformer_engine::DType::kBFloat16 ||
           dtype == transformer_engine::DType::kFloat16 ||
           dtype == transformer_engine::DType::kFloat32;
  };
  NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()),
             "Grouped GEMM inputs must be FP8, BF16, or FP16.");
  // Only check C dtype if C is provided
  if (inputC != nullptr) {
    NVTE_CHECK(is_output_dtype(inputC->dtype()), "Grouped GEMM: C must be BF16, FP16, or FP32.");
  }
  NVTE_CHECK(is_output_dtype(outputD->dtype()), "Grouped GEMM: D must be BF16, FP16, or FP32.");

  NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(),
             "Grouped GEMM: A tensor is missing both row-wise and column-wise data");
  NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(),
             "Grouped GEMM: B tensor is missing both row-wise and column-wise data");
}
// Select row-wise vs column-wise storage and adjust transpose flag for grouped GEMM.
// Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and
// fallback to column-wise data when row-wise is absent.
// Contains all information needed for GEMM setup - shape already accounts for storage layout.
struct GroupedOperandSelection {
  TensorShapeInfo shape;  // Shape info with dims already swapped for columnwise if needed
  char *dptr = nullptr;
  void *scale_inv = nullptr;
  transformer_engine::DType dtype = transformer_engine::DType::kNumTypes;
  bool trans = false;
};

// Helper to create TensorShapeInfo from a GroupedTensor, optionally swapping first/last dims.
// When swap_dims=true, first_dims and last_dims are swapped to account for columnwise storage.
// Note: tensor_offsets are the same for rowwise and columnwise data (same element count per tensor).
inline TensorShapeInfo create_shape_info(const transformer_engine::GroupedTensor *t,
                                         bool swap_dims) {
  const bool has_first = t->first_dims.has_data();
  const bool has_last = t->last_dims.has_data();
  NVTE_CHECK(has_first || t->all_same_first_dim(),
             "GroupedTensor is missing first_dims for varying shapes");
  NVTE_CHECK(has_last || t->all_same_last_dim(),
             "GroupedTensor is missing last_dims for varying shapes");
  const int64_t *first_ptr =
      has_first ? static_cast<const int64_t *>(t->first_dims.dptr) : nullptr;
  const int64_t *last_ptr = has_last ? static_cast<const int64_t *>(t->last_dims.dptr) : nullptr;
  const int64_t uniform_first = has_first ? 0 : static_cast<int64_t>(t->get_common_first_dim());
  const int64_t uniform_last = has_last ? 0 : static_cast<int64_t>(t->get_common_last_dim());
  const int64_t *offsets_ptr = t->tensor_offsets.has_data()
                                   ? static_cast<const int64_t *>(t->tensor_offsets.dptr)
                                   : nullptr;
  if (swap_dims) {
    // Swap first/last to account for columnwise (transposed) storage
    return {last_ptr, first_ptr, offsets_ptr, uniform_last, uniform_first};
  }
  return {first_ptr, last_ptr, offsets_ptr, uniform_first, uniform_last};
}
inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor *t,
                                                      bool trans, bool is_A) {
  using namespace transformer_engine;
  const bool has_row = t->has_data();
  const bool has_col = t->has_columnwise_data();
  NVTE_CHECK(has_row || has_col,
             "Grouped GEMM operand is missing both row-wise and column-wise data");

  // Currently only unquantized data and tensor-scaled FP8 are supported.
  const auto sm = t->scaling_mode;
  NVTE_CHECK(sm == NVTE_DELAYED_TENSOR_SCALING,
             "Grouped GEMM is only supported with unquantized data and tensor-scaled FP8 data");

  const DType row_dtype = t->data.dtype;
  const DType col_dtype = t->columnwise_data.dtype;

  GroupedOperandSelection sel;
  sel.trans = trans;

  const DType rep_dtype = has_row ? row_dtype : col_dtype;
  const bool is_fp8 = is_fp8_dtype(rep_dtype);
  const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported();

  // Helper to select columnwise storage (swaps dims in shape)
  auto use_columnwise = [&]() {
    sel.dptr = static_cast<char *>(t->columnwise_data.dptr);
    sel.scale_inv = t->columnwise_scale_inv.dptr;
    sel.dtype = col_dtype;
    sel.shape = create_shape_info(t, /*swap_dims=*/true);
  };
  // Helper to select row-wise storage
  auto use_rowwise = [&]() {
    sel.dptr = static_cast<char *>(t->data.dptr);
    sel.scale_inv = t->scale_inv.dptr;
    sel.dtype = row_dtype;
    sel.shape = create_shape_info(t, /*swap_dims=*/false);
  };

  // Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed.
  if (is_fp8 && !non_tn_fp8_ok) {
    if (is_A) {
      if (!sel.trans) {
        NVTE_CHECK(has_col,
                   "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout");
        use_columnwise();
        sel.trans = true;  // using pre-transposed storage
        return sel;
      }
    } else {  // B
      if (sel.trans) {
        NVTE_CHECK(has_col,
                   "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout");
        use_columnwise();
        sel.trans = false;  // using pre-transposed storage
        return sel;
      }
    }
  }

  // If only column-wise data is available, mirror the transpose flag (pre-transposed storage).
  if (!has_row && has_col) {
    // On Hopper FP8, this would break TN requirement - should have been handled above
    NVTE_CHECK(!is_fp8 || non_tn_fp8_ok,
               "Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration");
    use_columnwise();
    sel.trans = !trans;  // flip transpose for pre-transposed storage
    return sel;
  }

  // Default: use row-wise data
  use_rowwise();
  return sel;
}
inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size_t required_size,
                                            const char *workspace_name) {
  NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null.");
  const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype);
  NVTE_CHECK(provided_size >= required_size, "Grouped GEMM: Insufficient ", workspace_name,
             ". Required: ", required_size, " bytes, Available: ", provided_size, " bytes.");
  return ws->data.dptr;
}
inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA,
                                cublasLtMatrixLayoutOpaque_t &descB,
                                cublasLtMatrixLayoutOpaque_t &descC,
                                cublasLtMatrixLayoutOpaque_t &descD,
                                const GroupedGemmSetupWorkspace &ws,
                                const GroupedOperandSelection &A_sel,
                                const GroupedOperandSelection &B_sel,
                                const transformer_engine::GroupedTensor *D, size_t num_tensors) {
  const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype);
  const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype);
  const cudaDataType_t D_type = get_cuda_dtype(D->dtype());

  // Storage dimensions computed by kernel, leading dimension = rows
  NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, ws.a_rows,
                                                    ws.a_cols, ws.a_rows));
  NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, ws.b_rows,
                                                    ws.b_cols, ws.b_rows));
  NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.d_rows,
                                                    ws.d_cols, ws.d_rows));
  NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.d_rows,
                                                    ws.d_cols, ws.d_rows));
}
inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A,
                             cublasOperation_t op_B) {
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A,
                                                   sizeof(op_A)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B,
                                                   sizeof(op_B)));
  cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
                                                   &pointer_mode, sizeof(pointer_mode)));
  int64_t alphabeta_batch_stride = 1;
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc,
                                                   CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE,
                                                   &alphabeta_batch_stride, sizeof(int64_t)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc,
                                                   CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE,
                                                   &alphabeta_batch_stride, sizeof(int64_t)));
}
inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc,
                                   const GroupedOperandSelection &A_sel,
                                   const GroupedOperandSelection &B_sel) {
  const bool is_fp8_a = is_fp8_dtype(A_sel.dtype);
  const bool is_fp8_b = is_fp8_dtype(B_sel.dtype);
  if (!is_fp8_a && !is_fp8_b) return;
  if (is_fp8_a) {
    void *a_scale_inv = A_sel.scale_inv;
    NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required");
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc,
                                                     CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                     &a_scale_inv, sizeof(a_scale_inv)));
  }
  if (is_fp8_b) {
    void *b_scale_inv = B_sel.scale_inv;
    NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required");
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc,
                                                     CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                     &b_scale_inv, sizeof(b_scale_inv)));
  }
}
// Constants for grouped GEMM workspace (declared early for use in heuristics)
static constexpr size_t kGroupedGemmAlignment = 256;
static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024;  // 32 MiB

inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(
    cublasLtHandle_t handle, cublasLtMatmulDescOpaque_t &matmulDesc,
    cublasLtMatrixLayoutOpaque_t &descA, cublasLtMatrixLayoutOpaque_t &descB,
    cublasLtMatrixLayoutOpaque_t &descC, cublasLtMatrixLayoutOpaque_t &descD, int64_t avg_m,
    int64_t avg_n, int64_t avg_k) {
  cublasLtMatmulPreferenceOpaque_t preference;
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &kGroupedGemmCublasWorkspaceSize,
      sizeof(size_t)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      &preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_COLS, &avg_n, sizeof(int64_t)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      &preference, CUBLASLT_MATMUL_PREF_GROUPED_AVERAGE_REDUCTION_DIM, &avg_k, sizeof(int64_t)));

  cublasLtMatmulHeuristicResult_t heuristicResult;
  int returnedResults = 0;
  auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD,
                                               &preference, 1, &heuristicResult, &returnedResults);
  NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
             "Unable to find suitable cuBLAS grouped GEMM algorithm");
  NVTE_CHECK_CUBLAS(status);
  NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM");
  return heuristicResult.algo;
}
// Single kernel that sets up all GEMM parameters.
// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix dimensions,
// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes.
// We bridge the mismatch on GPU by computing per-group pointers and storage dims in one kernel.
__global__ void setup_grouped_gemm_kernel(
    // Output arrays
    void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *a_rows, int *a_cols,
    int *b_rows, int *b_cols, int *d_rows, int *d_cols, float **alpha_ptrs, float **beta_ptrs,
    // Inputs
    char *a_base, char *b_base, char *c_base, char *d_base, TensorShapeInfo A_meta,
    TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, size_t a_elem_size,
    size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, float *alpha_ptr, float *beta_ptr,
    size_t num_tensors) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= num_tensors) return;

  // Get dimensions for this tensor (from array or uniform value)
  int64_t a_first = A_meta.first_dims ? A_meta.first_dims[idx] : A_meta.uniform_first;
  int64_t a_last = A_meta.last_dims ? A_meta.last_dims[idx] : A_meta.uniform_last;
  int64_t b_first = B_meta.first_dims ? B_meta.first_dims[idx] : B_meta.uniform_first;
  int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last;
  int64_t d_first = D_meta.first_dims ? D_meta.first_dims[idx] : D_meta.uniform_first;
  int64_t d_last = D_meta.last_dims ? D_meta.last_dims[idx] : D_meta.uniform_last;

  // Compute offsets (from array or compute from uniform dims)
  int64_t a_offset =
      A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last);
  int64_t b_offset =
      B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last);
  int64_t c_offset =
      C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last);
  int64_t d_offset =
      D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last);

  // Compute data pointers
  A_ptrs[idx] = a_base + a_offset * a_elem_size;
  B_ptrs[idx] = b_base + b_offset * b_elem_size;
  C_ptrs[idx] = c_base + c_offset * c_elem_size;
  D_ptrs[idx] = d_base + d_offset * d_elem_size;

  // Compute storage dimensions for cuBLAS matrix layouts.
  // For INPUTS (A, B): Row-wise storage is seen as transposed column-major by cuBLAS,
  // so rows=last, cols=first. For columnwise, dims are already swapped.
  a_rows[idx] = static_cast<int>(a_last);
  a_cols[idx] = static_cast<int>(a_first);
  b_rows[idx] = static_cast<int>(b_last);
  b_cols[idx] = static_cast<int>(b_first);
  // For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N).
  d_rows[idx] = static_cast<int>(d_first);
  d_cols[idx] = static_cast<int>(d_last);

  // Fill alpha/beta pointers (per-matrix)
  alpha_ptrs[idx] = alpha_ptr + idx;
  beta_ptrs[idx] = beta_ptr + idx;
}
// Launch the setup kernel to populate workspace arrays
inline void launch_grouped_gemm_setup(const GroupedGemmSetupWorkspace &ws,
                                      const GroupedOperandSelection &A_sel,
                                      const GroupedOperandSelection &B_sel,
                                      const transformer_engine::GroupedTensor *C,
                                      const transformer_engine::GroupedTensor *D,
                                      const transformer_engine::Tensor *alpha_tensor,
                                      const transformer_engine::Tensor *beta_tensor,
                                      size_t num_tensors, cudaStream_t stream) {
  // Use shape info from selection (already accounts for columnwise dimension swap)
  TensorShapeInfo A_meta = A_sel.shape;
  TensorShapeInfo B_meta = B_sel.shape;
  TensorShapeInfo C_meta = TensorShapeInfo::create_shape_info_for_C(C, D);
  TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D);

  char *c_base = static_cast<char *>(C->data.dptr);
  char *d_base = static_cast<char *>(D->data.dptr);
  const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype);
  const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype);
  const size_t c_elem_size = transformer_engine::typeToSize(C->dtype());
  const size_t d_elem_size = transformer_engine::typeToSize(D->dtype());

  const int threads_per_block = 256;
  const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block;
  setup_grouped_gemm_kernel<<<num_blocks, threads_per_block, 0, stream>>>(
      ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.a_rows, ws.a_cols, ws.b_rows, ws.b_cols,
      ws.d_rows, ws.d_cols, ws.alpha_ptrs, ws.beta_ptrs, A_sel.dptr, B_sel.dptr, c_base, d_base,
      A_meta, B_meta, C_meta, D_meta, a_elem_size, b_elem_size, c_elem_size, d_elem_size,
      static_cast<float *>(alpha_tensor->data.dptr), static_cast<float *>(beta_tensor->data.dptr),
      num_tensors);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) {
  return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment);
}

}  // namespace
void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B,
                       int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D,
                       const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup,
                       NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config,
                       cudaStream_t stream) {
  NVTE_API_CALL(nvte_grouped_gemm);
  using namespace transformer_engine;

  // Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.2+
  const int current_device = transformer_engine::cuda::current_device();
  NVTE_CHECK(transformer_engine::cuda::sm_arch(current_device) >= 100,
             "nvte_grouped_gemm requires Blackwell (SM100) or newer architecture.");
  NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 130200,
             "nvte_grouped_gemm requires cuBLAS 13.2+, but run-time cuBLAS version is ",
             transformer_engine::cuda::cublas_version());

  // Convert to internal types
  const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A);
  const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B);
  const GroupedTensor *inputC_raw = convertNVTEGroupedTensor(C);  // Can be NULL
  GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D);
  const Tensor *alpha_tensor = convertNVTETensorCheck(alpha);
  const Tensor *beta_tensor = convertNVTETensorCheck(beta);
  Tensor *wspace_setup = convertNVTETensor(workspace_setup);
  Tensor *wspace_cublas = convertNVTETensor(workspace_cublas);

  // Parse config (if provided)
  GroupedMatmulConfig config_;
  if (config != nullptr) {
    config_ = *reinterpret_cast<GroupedMatmulConfig *>(config);
  }

  // Validate inputs and num_tensors
  validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD, alpha_tensor, beta_tensor);

  // If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data)
  const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD;

  const size_t num_tensors = inputA->num_tensors;

  // Select operand storage (row-wise vs column-wise) and adjust transpose flags to
  // mirror the non-grouped GEMM logic for FP8 layout constraints.
  const auto A_sel = select_grouped_operand(inputA, static_cast<bool>(transa), /*is_A=*/true);
  const auto B_sel = select_grouped_operand(inputB, static_cast<bool>(transb), /*is_A=*/false);

  // Workspaces: setup (pointer arrays) and cuBLAS
  const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors);
  const size_t cublas_workspace_size = kGroupedGemmCublasWorkspaceSize;
  void *setup_workspace_ptr = validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size,
                                                             "Grouped GEMM setup workspace");
  void *cublas_workspace_ptr = validate_and_get_workspace_ptr(
      wspace_cublas, cublas_workspace_size, "Grouped GEMM cuBLAS workspace");
  auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers(
      static_cast<char *>(setup_workspace_ptr), num_tensors);

  launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor,
                            beta_tensor, num_tensors, stream);

  // Get cuBLAS handle
  using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
  cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();

  // Setup cuBLAS operations
  cublasOperation_t op_A = A_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t op_B = B_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N;

  // Create grouped matrix layouts
  cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD;
  init_matrix_layouts(descA, descB, descC, descD, setup_workspace, A_sel, B_sel, outputD,
                      num_tensors);

  // Create matmul descriptor
  cublasLtMatmulDescOpaque_t matmulDesc;
  init_matmul_desc(matmulDesc, op_A, op_B);
  set_fp8_scale_pointers(matmulDesc, A_sel, B_sel);

  // Compute average dimensions for heuristics
  // K dimension: if transa, K is A's first dim; if not, K is A's last dim
  // Use original inputA and transa for heuristics (not modified A_sel.trans)
  int64_t avg_m_val = config_.avg_m.value_or(compute_avg_first_dim(outputD));
  int64_t avg_n_val = config_.avg_n.value_or(compute_avg_last_dim(outputD));
  int64_t avg_k_val = config_.avg_k.value_or(transa ? compute_avg_first_dim(inputA)
                                                    : compute_avg_last_dim(inputA));

  // Heuristic selection
  cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC,
                                                       descD, avg_m_val, avg_n_val, avg_k_val);

  // Execute the grouped GEMM
  NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs,
                                   setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB,
                                   setup_workspace.beta_ptrs, setup_workspace.C_ptrs, &descC,
                                   setup_workspace.D_ptrs, &descD, &algo, cublas_workspace_ptr,
                                   kGroupedGemmCublasWorkspaceSize, stream));
}
#else  // CUBLAS_VERSION < 130200

void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B,
                       int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D,
                       const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup,
                       NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config,
                       cudaStream_t stream) {
  NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.2+, but compile-time cuBLAS version is ",
             CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer.");
}

#endif  // CUBLAS_VERSION >= 130200
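// Hedged call sketch (illustrative only, not part of this commit). It assumes the grouped
// tensors, alpha/beta tensors, and the two workspace tensors were created elsewhere, and that
// <transformer_engine/gemm.h> declares nvte_grouped_gemm as in this file.
#include <cuda_runtime.h>
#include <transformer_engine/gemm.h>

void run_grouped_gemm_sketch(const NVTEGroupedTensor A, const NVTEGroupedTensor B,
                             NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta,
                             NVTETensor ws_setup, NVTETensor ws_cublas, cudaStream_t stream) {
  // C may be NULL when beta is zero; the implementation then aliases C to D.
  // A null config falls back to averages computed from the tensor shapes.
  nvte_grouped_gemm(A, /*transa=*/1, B, /*transb=*/0, /*C=*/nullptr, D, alpha, beta, ws_setup,
                    ws_cublas, /*config=*/nullptr, stream);
}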
transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu
0 → 100644
View file @
9df0c4a3
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <transformer_engine/hadamard_transform.h>
#include <transformer_engine/multi_tensor.h>
#include <cuda/barrier>
#include "common/common.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "hadamard_transform_utils.cuh"
namespace transformer_engine {
namespace {

constexpr int kMaxTensorsPerKernel = 64;
constexpr int kThreadsPerWarp = 32;

enum ShapeRepresentation {
  SAME_BOTH_DIMS = 0,
  VARYING_FIRST_DIM = 1,
  VARYING_LAST_DIM = 2,
  VARYING_BOTH_DIMS = 3
};

__device__ __forceinline__ size_t get_current_tensor_id(
    const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset,
    const size_t first_logical_dim, const size_t last_logical_dim,
    const int64_t *const __restrict__ offsets_ptr) {
  if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) {
    const size_t current_row = current_offset / last_logical_dim;
    const size_t rows_per_tensor = first_logical_dim / num_tensors;
    return current_row / rows_per_tensor;
  } else {
    // upper_bound(offsets, current_offset) - 1 in range i in [0..num_tensors)
    size_t low = 0;
    size_t hi = num_tensors;  // half-open [low, hi)
    while (low < hi) {
      const size_t mid = low + (hi - low) / 2;
      const size_t mid_offset = static_cast<size_t>(offsets_ptr[mid]);
      if (mid_offset <= current_offset) {
        low = mid + 1;
      } else {
        hi = mid;
      }
    }
    // low = first index where offsets[low] > current_offset (or low == num_tensors)
    // id = low - 1, but need to evaluate if current_offset < offsets[0]
    return (low == 0) ? 0 : (low - 1);
  }
}
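
The else-branch above is the standard upper_bound-minus-one lookup over the per-tensor prefix offsets. For reference, a minimal host-side sketch of the same lookup written with the standard library; the helper name tensor_id_host is hypothetical and used only for illustration, it is not part of this file:

#include <algorithm>
#include <cstdint>

// Illustration only: host-side equivalent of the device binary search above.
// offsets must be sorted ascending and hold the starting element offset of each tensor.
inline size_t tensor_id_host(const int64_t *offsets, size_t num_tensors, size_t current_offset) {
  const int64_t *end = offsets + num_tensors;
  const int64_t *it = std::upper_bound(offsets, end, static_cast<int64_t>(current_offset));
  return (it == offsets) ? 0 : static_cast<size_t>(it - offsets) - 1;
}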
template <typename IType, int kHadamardDimension, int BUFF_DIM_Y, int BUFF_DIM_X,
          bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
                                               IType *in_sh_ptr,
                                               uint32_t &local_pre_rht_amax_reg,
                                               uint32_t &local_amax_reg,
                                               uint32_t &local_amax_t_reg) {
  uint32_t a_frag[4];  // A matrix fragment
  uint32_t c_frag[4];  // Result fragment
  int warp_id = threadIdx.x / kThreadsPerWarp;
  int local_rank = (threadIdx.x % kThreadsPerWarp);
  int ld_row_idx = local_rank % kHadamardDimension;
  int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
  int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);
  uint32_t temp_amax_reg;
  uint32_t temp_amax_t_reg;

  if (kReturnIdentityAmax) {
    ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                       reinterpret_cast<uint4 *>(in_sh_ptr) + swizzle_idx);
    mma_m16_n16_k16_b16_b16_b16_noacc<kReturnIdentityAmax>(
        a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
        b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(local_amax_reg)
                 : "r"(local_amax_reg), "r"(temp_amax_reg));
  }

  if (kReturnTransposedAmax) {
    // TODO(Frank): This is not efficient, since we could directly load the
    // matrix in transposed layout.
    if (!kReturnIdentityAmax) {
      ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                         reinterpret_cast<uint4 *>(in_sh_ptr) + swizzle_idx);
    }
    matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
    matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
    matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
    matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
    mma_m16_n16_k16_b16_b16_b16_noacc<kReturnTransposedAmax>(
        a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2],
        b_frag_t[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_t_reg);
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(local_amax_t_reg)
                 : "r"(local_amax_t_reg), "r"(temp_amax_t_reg));
  }

  if (kReturnPreRhtAmax) {
    if (!kReturnIdentityAmax && !kReturnTransposedAmax) {
      ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                         reinterpret_cast<uint4 *>(in_sh_ptr) + swizzle_idx);
    }
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(a_frag[0])
                 : "r"(a_frag[0]), "r"(a_frag[1]));
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(a_frag[2])
                 : "r"(a_frag[2]), "r"(a_frag[3]));
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(a_frag[0])
                 : "r"(a_frag[0]), "r"(a_frag[2]));
    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
                 : "=r"(local_pre_rht_amax_reg)
                 : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg));
  }
}
template <int kN>
__device__ __host__ constexpr int NextPowerOf2() {
  static_assert(kN > 0, "kN must be > 0");
  // Round up to the next power of 2 by counting leading zeros.
  return 1 << (32 - __builtin_clz(kN - 1));
}
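
A couple of compile-time checks illustrating the rounding behaviour; these asserts are not part of the file, they only document the intent:

static_assert(NextPowerOf2<4>() == 4, "a power of two is returned unchanged");
static_assert(NextPowerOf2<5>() == 8, "other values round up to the next power of two");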
template <int kNumWarps, bool kReturnPreRhtAmax, bool kReturnIdentityAmax,
          bool kReturnTransposedAmax>
__device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float identity_amax,
                                          const float transpose_amax, float *staging_for_pre_rht,
                                          float *staging_for_identity,
                                          float *staging_for_transpose,
                                          float *output_pre_rht_amax_ptr,
                                          float *output_identity_amax_ptr,
                                          float *output_transpose_amax_ptr, const int warpid) {
  // intra-warp reduction
  constexpr int kWarpSize = 32;
  int local_rank = threadIdx.x % 32;
  float warp_pre_rht_amax = kReturnPreRhtAmax ? warp_reduce_max<kWarpSize>(pre_rht_amax) : 0.0f;
  float warp_identity_amax = kReturnIdentityAmax ? warp_reduce_max<kWarpSize>(identity_amax) : 0.0f;
  float warp_transpose_amax =
      kReturnTransposedAmax ? warp_reduce_max<kWarpSize>(transpose_amax) : 0.0f;

  // inter-warp reduction
  if (threadIdx.x % 32 == 0) {
    if (kReturnPreRhtAmax) {
      staging_for_pre_rht[warpid] = warp_pre_rht_amax;
    }
    if (kReturnIdentityAmax) {
      staging_for_identity[warpid] = warp_identity_amax;
    }
    if (kReturnTransposedAmax) {
      staging_for_transpose[warpid] = warp_transpose_amax;
    }
  }
  __syncthreads();

  constexpr int kNumWarpsPow2 = NextPowerOf2<kNumWarps>();
  if (warpid == 0) {
    if (kReturnIdentityAmax) {
      float identity_accum = local_rank < kNumWarps ? staging_for_identity[local_rank] : 0.0f;
      identity_accum = warp_reduce_max<kNumWarpsPow2>(identity_accum);
      if (local_rank == 0) {
        atomicMaxFloat(output_identity_amax_ptr, identity_accum);
      }
    }
  }
  if (warpid == 1) {
    if (kReturnTransposedAmax) {
      float transpose_accum = local_rank < kNumWarps ? staging_for_transpose[local_rank] : 0.0f;
      transpose_accum = warp_reduce_max<kNumWarpsPow2>(transpose_accum);
      if (local_rank == 0) {
        atomicMaxFloat(output_transpose_amax_ptr, transpose_accum);
      }
    }
  }
  if (warpid == 2) {
    if (kReturnPreRhtAmax) {
      float pre_rht_accum = local_rank < kNumWarps ? staging_for_pre_rht[local_rank] : 0.0f;
      pre_rht_accum = warp_reduce_max<kNumWarpsPow2>(pre_rht_accum);
      if (local_rank == 0) {
        atomicMaxFloat(output_pre_rht_amax_ptr, pre_rht_accum);
      }
    }
  }
}
__global__ void GraphSafeMultiZeroAmaxKernel(const size_t num_tensors, float *amax_rowwise_ptr,
                                             float *amax_colwise_ptr) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  // Assign each thread a range for rowwise and colwise independently
  if (amax_rowwise_ptr != nullptr) {
    for (int i = tid; i < num_tensors; i += stride) {
      amax_rowwise_ptr[i] = 0.f;
    }
  }
  if (amax_colwise_ptr != nullptr) {
    for (int i = tid; i < num_tensors; i += stride) {
      amax_colwise_ptr[i] = 0.f;
    }
  }
}
__global__ void GraphSafeMultiAmaxMemcpyD2DKernelPreRHT(const size_t num_tensors,
                                                        float *amax_rowwise_ptr,
                                                        float *amax_colwise_ptr) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  if (amax_rowwise_ptr != nullptr && amax_colwise_ptr != nullptr) {
    for (; tid < num_tensors; tid += stride) {
      float *output_pre_rht_amax_ptr = amax_rowwise_ptr + tid;
      float *output_transpose_amax_ptr = amax_colwise_ptr + tid;
      *output_transpose_amax_ptr = *output_pre_rht_amax_ptr;
    }
  }
}
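
Functionally this kernel copies num_tensors pre-RHT amax values from the rowwise buffer into the colwise buffer. Ignoring the graph-safety considerations that motivate a dedicated kernel, the same data movement could be sketched with a plain device-to-device copy; this is an illustration only, not part of the file:

// Illustration only: equivalent data movement outside of the graph-safe path.
cudaMemcpyAsync(amax_colwise_ptr, amax_rowwise_ptr, num_tensors * sizeof(float),
                cudaMemcpyDeviceToDevice, stream);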
template <typename IType, int kHadamardDimension, int CHUNK_DIM_Y, int CHUNK_DIM_X, int BUFF_DIM_Y,
          int BUFF_DIM_X, int THREADS_PER_CHUNK, int THREADS_PER_Y, bool kReturnPreRhtAmax,
          bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__global__ void GraphSafeGroupHadamardAmaxTmaKernel(
    const __grid_constant__ CUtensorMap tensor_map_input, uint16_t random_sign_mask,
    uint16_t random_sign_mask_t, const ShapeRepresentation shape_rep, const size_t num_tensors,
    const size_t first_logical_dim, const size_t last_logical_dim,
    const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr,
    float *const __restrict__ amax_rowwise_ptr, float *const __restrict__ amax_colwise_ptr) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  float *output_pre_rht_amax_ptr;
  float *output_identity_amax_ptr = nullptr;
  float *output_transpose_amax_ptr;

  // calculate the global offset to get tensor id
  size_t global_offset = blockIdx.y * CHUNK_DIM_Y * last_logical_dim;
  int tensor_id = get_current_tensor_id(shape_rep, num_tensors, global_offset, first_logical_dim,
                                        last_logical_dim, offsets_ptr);
  output_pre_rht_amax_ptr = static_cast<float *>(amax_rowwise_ptr) + tensor_id;
  output_transpose_amax_ptr = static_cast<float *>(amax_colwise_ptr) + tensor_id;

  static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y && CHUNK_DIM_Y % BUFF_DIM_Y == 0);
  static_assert(CHUNK_DIM_X >= BUFF_DIM_X && CHUNK_DIM_X % BUFF_DIM_X == 0);
  constexpr size_t STAGES_Y = CHUNK_DIM_Y / BUFF_DIM_Y;
  constexpr size_t STAGES_X = CHUNK_DIM_X / BUFF_DIM_X;
  constexpr int kNumWarps = (THREADS_PER_CHUNK * THREADS_PER_Y) / kThreadsPerWarp;

  const int input_block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
  const int input_block_offset_X = blockIdx.x * CHUNK_DIM_X;

  extern __shared__ __align__(128) char dynamic_shmem[];
  uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) does not guarantee the pointer to be aligned!
  uint8_t *dshmem = reinterpret_cast<uint8_t *>((base_shmem_ptr + 127) & ~127ULL);

  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
  constexpr size_t in_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType);
  IType *in_sh_0 = reinterpret_cast<IType *>(dshmem);
  dshmem += in_buff_size;
  IType *in_sh_1 = reinterpret_cast<IType *>(dshmem);
  dshmem += in_buff_size;
  IType *in_shs[2] = {in_sh_0, in_sh_1};
  constexpr int shmem_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType);

  const bool is_master_thread = (threadIdx.x == 0 && threadIdx.y == 0);

  // Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  uint64_t *mbar = reinterpret_cast<uint64_t *>(dshmem);
  dshmem += sizeof(uint64_t) * (STAGES_X * STAGES_Y);
  float *max_staging_identity = reinterpret_cast<float *>(dshmem);
  dshmem += sizeof(float) * kNumWarps;
  float *max_staging_transpose = reinterpret_cast<float *>(dshmem);
  dshmem += sizeof(float) * kNumWarps;
  float *max_staging_pre_rht = reinterpret_cast<float *>(dshmem);
  dshmem += sizeof(float) * kNumWarps;

  initialize_barriers<STAGES_X * STAGES_Y, THREADS_PER_CHUNK * THREADS_PER_Y>(mbar,
                                                                              is_master_thread);

  copy_2d_to_shared(in_shs[0], reinterpret_cast<const void *>(&tensor_map_input),
                    input_block_offset_X, input_block_offset_Y, shmem_buff_size, &mbar[0],
                    is_master_thread);

  uint32_t had_frag_i[4];
  uint32_t had_frag_t[4];
  get_hadamard_matrix_fragment<kReturnIdentityAmax, kReturnTransposedAmax, false, false>(
      had_frag_i, random_sign_mask, had_frag_t, random_sign_mask_t);

  float local_pre_rht_amax = 0.0;
  float local_amax = 0.0;
  float local_amax_t = 0.0;
  uint32_t local_pre_rht_amax_reg = *reinterpret_cast<uint32_t *>(&local_pre_rht_amax);
  uint32_t local_amax_reg = *reinterpret_cast<uint32_t *>(&local_amax);
  uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t *>(&local_amax_t);

  for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
    for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
      int stage = STAGES_X * stage_y + stage_x;
      const int next_stage = stage + 1;
      const int next_stage_x = stage_x + 1 == STAGES_X ? 0 : stage_x + 1;
      const int next_stage_y = stage_x + 1 == STAGES_X ? stage_y + 1 : stage_y;
      if (next_stage < STAGES_X * STAGES_Y) {
        const int input_global_offset_Y = input_block_offset_Y + next_stage_y * BUFF_DIM_Y;
        const int input_global_offset_X = input_block_offset_X + next_stage_x * BUFF_DIM_X;
        copy_2d_to_shared(in_shs[next_stage % 2],  // ping-pong
                          reinterpret_cast<const void *>(&tensor_map_input),
                          input_global_offset_X, input_global_offset_Y, shmem_buff_size,
                          &mbar[next_stage], is_master_thread);
      }
      ptx::fence_proxy_async_shared_cta();
      // Wait for the data to have arrived
      ptx::mbarrier_wait_parity(&mbar[stage], 0);

      const size_t compute_stage_x_num =
          BUFF_DIM_X / (kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp));
      const size_t compute_stage_y_num = BUFF_DIM_Y / (kHadamardDimension * THREADS_PER_Y);
      const size_t in_row_stride = BUFF_DIM_X;
      IType *in_sh_ptr = in_shs[stage % 2];
#pragma unroll
      for (size_t compute_stage_y = 0; compute_stage_y < compute_stage_y_num; compute_stage_y++) {
        const int row_idx_offset = (compute_stage_y * kHadamardDimension * THREADS_PER_Y +
                                    threadIdx.y * kHadamardDimension);
        const int in_row_offset = row_idx_offset * in_row_stride;
#pragma unroll
        for (size_t compute_stage_x = 0; compute_stage_x < compute_stage_x_num;
             compute_stage_x++) {
          ComputeKernel<IType, kHadamardDimension, BUFF_DIM_Y, BUFF_DIM_X, kReturnPreRhtAmax,
                        kReturnIdentityAmax, kReturnTransposedAmax>(
              had_frag_i, had_frag_t,
              in_sh_ptr + in_row_offset +
                  (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)),
              local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
        }
        // Ensure all threads have finished their computation before new data over-writes the
        // shared memory.
        __syncthreads();
      }
    }
  }

  const int warpid = (threadIdx.x + threadIdx.y * blockDim.x) / kThreadsPerWarp;
  if constexpr (kReturnPreRhtAmax) {
    unpack_max_of_packed_bf16(local_pre_rht_amax_reg, local_pre_rht_amax);
  }
  if constexpr (kReturnIdentityAmax) {
    unpack_max_of_packed_bf16(local_amax_reg, local_amax);
  }
  if constexpr (kReturnTransposedAmax) {
    unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t);
  }
  ReduceMax<kNumWarps, kReturnPreRhtAmax, kReturnIdentityAmax, kReturnTransposedAmax>(
      local_pre_rht_amax, local_amax, local_amax_t, max_staging_pre_rht, max_staging_identity,
      max_staging_transpose, output_pre_rht_amax_ptr, output_identity_amax_ptr,
      output_transpose_amax_ptr, warpid);
  destroy_barriers<STAGES_X * STAGES_Y>(mbar, is_master_thread);
#else
  NVTE_DEVICE_ERROR("Kernel is only supported on SM 10.0+.");
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}  // namespace

// broadcast_pre_rht_amax: when true, the Hadamard transform is disabled. In that case, if the
// output amax buffers expect both amax_rowwise and amax_colwise, MultiAmaxMemcpyD2DKernelPreRHT
// is called to perform a D2D copy of the amax values.
void group_hadamard_transform_amax_graph_safe(const GroupedTensor *input, GroupedTensor *output,
                                              uint16_t random_sign_mask,
                                              uint16_t random_sign_mask_t,
                                              bool broadcast_pre_rht_amax, cudaStream_t stream) {
  NVTE_API_CALL(group_hadamard_transform_amax_graph_safe);
#if CUDA_VERSION >= 12080
  NVTE_CHECK(input->num_tensors == output->num_tensors,
             "Number of input and output tensors must be same.");
  NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data.");
  checkCuDriverContext(stream);

  bool all_return_pre_rht_amax = output->has_data();
  // there is no rowwise RHT transform in current recipe
  bool all_return_identity_amax = false;
  bool all_return_transposed_amax = output->has_columnwise_data();
  NVTE_CHECK(all_return_pre_rht_amax || all_return_identity_amax || all_return_transposed_amax,
             "At least one of return_pre_rht_amax, return_identity_amax, or return_transposed_amax "
             "must be true");
  if (broadcast_pre_rht_amax) {
    NVTE_CHECK(all_return_pre_rht_amax,
               "broadcast_pre_rht_amax is only supported when we compute pre-RHT amax");
    // If all_return_identity_amax and all_return_transposed_amax are both false, there is
    // nothing to broadcast.
    broadcast_pre_rht_amax &= (all_return_identity_amax || all_return_transposed_amax);
  }

  const size_t num_tensors = input->num_tensors;
  const size_t first_logical_dim = input->logical_shape.data[0];
  const size_t last_logical_dim = input->logical_shape.data[1];
  // const size_t elts_total = first_logical_dim * last_logical_dim;
  NVTE_CHECK(first_logical_dim % 128 == 0,
             "First dimension of a grouped tensor should be divisible by 128.");
  NVTE_CHECK(last_logical_dim % 128 == 0,
             "Last dimension of a grouped tensor should be divisible by 128.");

  float *const amax_rowwise_ptr = reinterpret_cast<float *>(output->amax.dptr);
  float *const amax_colwise_ptr = reinterpret_cast<float *>(output->columnwise_amax.dptr);
  const int64_t *const offsets_ptr = reinterpret_cast<const int64_t *>(input->tensor_offsets.dptr);
  const int64_t *const first_dims_ptr = reinterpret_cast<const int64_t *>(input->first_dims.dptr);
  // const int64_t *const last_dims_ptr = reinterpret_cast<const int64_t *>(input->last_dims.dptr);

  // some sanity checks
  if (all_return_pre_rht_amax) {
    NVTE_CHECK(amax_rowwise_ptr != nullptr, "Amax rowwise pointer should not be nullptr.");
  }
  if (all_return_transposed_amax) {
    NVTE_CHECK(amax_colwise_ptr != nullptr, "Amax columnwise pointer should not be nullptr.");
  }

  // Zero out the per-tensor amax buffers if needed
  dim3 block_setup_amax(kMaxTensorsPerKernel);
  dim3 grid_setup_amax(1);
  GraphSafeMultiZeroAmaxKernel<<<grid_setup_amax, block_setup_amax, 0, stream>>>(
      num_tensors, amax_rowwise_ptr, amax_colwise_ptr);
  NVTE_CHECK_CUDA(cudaGetLastError());

  using IType = bf16;
  constexpr int kHadamardDimension = 16;
  // four (1x4) 64x64 sub-tiles for ping-pong overlap
  constexpr uint64_t kChunkBlockXSmall = 256;
  constexpr uint64_t kChunkBlockYSmall = 64;
  constexpr uint64_t kBuffDimX = 64;
  constexpr uint64_t kBuffDimY = 64;

  alignas(64) CUtensorMap tensor_map_input{};
  create_2D_tensor_map(/*tensorMap=*/tensor_map_input, /*tensor=*/input->data,
                       /*globalY=*/first_logical_dim, /*globalX=*/last_logical_dim,
                       /*shmemY=*/kBuffDimY, /*shmemX=*/kBuffDimX,
                       /*stride_elems=*/last_logical_dim, /*offset_elems=*/0,
                       /*type_num_bits=*/sizeof(IType) * 8,
                       /*swizzle=*/CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B);

  constexpr uint64_t kThreadBlockX = 4;
  constexpr uint64_t kThreadBlockY = 1;
  constexpr uint64_t kNumWarps = kThreadBlockX * kThreadBlockY;
  dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY);
  dim3 grid(DIVUP(last_logical_dim, kChunkBlockXSmall),
            DIVUP(first_logical_dim, kChunkBlockYSmall));

  ShapeRepresentation shape_rep = ShapeRepresentation::VARYING_FIRST_DIM;
  if (output->all_same_shape()) {
    shape_rep = ShapeRepresentation::SAME_BOTH_DIMS;
  } else if (output->all_same_first_dim()) {
    shape_rep = ShapeRepresentation::VARYING_LAST_DIM;
  } else if (output->all_same_last_dim()) {
    shape_rep = ShapeRepresentation::VARYING_FIRST_DIM;
  } else if (output->varying_both_dims()) {
    shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS;
  }
  const bool is_const_last_dim = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS ||
                                  shape_rep == ShapeRepresentation::VARYING_FIRST_DIM);
  NVTE_CHECK(is_const_last_dim,
             "Currently we only support const last dimension for graph safe hadamard transform.");

  TRANSFORMER_ENGINE_SWITCH_CONDITION(
      (all_return_transposed_amax && !broadcast_pre_rht_amax), kReturnTransposedAmax,
      TRANSFORMER_ENGINE_SWITCH_CONDITION(
          (all_return_identity_amax && !broadcast_pre_rht_amax), kReturnIdentityAmax,
          TRANSFORMER_ENGINE_SWITCH_CONDITION(
              all_return_pre_rht_amax, kReturnPreRhtAmax,
              // *2 for ping-pong
              size_t in_sh_size = kBuffDimX * kBuffDimY * 2 * sizeof(IType);
              size_t mbar_size = sizeof(uint64_t) * (kChunkBlockXSmall / kBuffDimX) *
                                 (kChunkBlockYSmall / kBuffDimY);
              size_t shmem_bytes = in_sh_size + mbar_size + kNumWarps * sizeof(float) * 3;
              // Add padding in case shmem ptr is not aligned to 128 bytes.
              shmem_bytes = (shmem_bytes + 128);
              auto kernel = GraphSafeGroupHadamardAmaxTmaKernel<
                  IType, kHadamardDimension, kChunkBlockYSmall, kChunkBlockXSmall, kBuffDimY,
                  kBuffDimX, kThreadBlockX * kThreadsPerWarp, kThreadBlockY, kReturnPreRhtAmax,
                  kReturnIdentityAmax, kReturnTransposedAmax>;
              cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                   shmem_bytes);
              kernel<<<grid, block, shmem_bytes, stream>>>(
                  tensor_map_input, random_sign_mask, random_sign_mask_t, shape_rep, num_tensors,
                  first_logical_dim, last_logical_dim, offsets_ptr, first_dims_ptr,
                  amax_rowwise_ptr, amax_colwise_ptr);
              if (broadcast_pre_rht_amax) {
                GraphSafeMultiAmaxMemcpyD2DKernelPreRHT<<<grid_setup_amax, block_setup_amax, 0,
                                                          stream>>>(num_tensors, amax_rowwise_ptr,
                                                                    amax_colwise_ptr);
              })));
  NVTE_CHECK_CUDA(cudaGetLastError());
#else
  NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ",
             CUDA_VERSION);
#endif  // CUDA_VERSION >= 12080
}
}  // namespace transformer_engine
void nvte_group_hadamard_transform_amax_graph_safe(const NVTEGroupedTensor input,
                                                   NVTEGroupedTensor output, int random_sign_mask,
                                                   int random_sign_mask_t, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_hadamard_transform_amax_graph_safe);
  using namespace transformer_engine;
  GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input);
  GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output);
  if (input_tensor->num_tensors == 0) {
    return;
  }
  // Call the grouped-tensor Hadamard transform amax implementation.
  group_hadamard_transform_amax_graph_safe(input_tensor, output_tensor,
                                           static_cast<uint16_t>(random_sign_mask),
                                           static_cast<uint16_t>(random_sign_mask_t), false,
                                           stream);
}
// Grouped-tensor amax without doing a Hadamard transform
void nvte_group_amax_graph_safe(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_amax_graph_safe);
  using namespace transformer_engine;
  GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input);
  GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output);
  if (input_tensor->num_tensors == 0) {
    return;
  }
  group_hadamard_transform_amax_graph_safe(input_tensor, output_tensor, 0, 0, true, stream);
}
transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu
0 → 100644
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <cutlass/arch/barrier.h>
#include <transformer_engine/hadamard_transform.h>
#include <cuda/barrier>
#include <cute/algorithm/gemm.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#include "common/common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/curanddx.hpp"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "customized_pipeline.cuh"
#include "cutlass/arch/barrier.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/fast_math.h"
#include "cutlass/float8.h"
#include "cutlass/float_subbyte.h"
#include "cutlass/gemm/collective/builders/sm100_common.inl"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/platform/platform.h"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/print_error.hpp"
namespace transformer_engine {
namespace detail {
namespace {

using namespace cute;
// Ensure Tensor refers to cute::Tensor, not transformer_engine::Tensor
using cute::Tensor;

constexpr int kMaxTensorsPerKernel = 64;
constexpr int kNVFP4BlockSize = 16;

enum ShapeRepresentation {
  SAME_BOTH_DIMS = 0,
  VARYING_FIRST_DIM = 1,
  VARYING_LAST_DIM = 2,
  VARYING_BOTH_DIMS = 3
};

__device__ __forceinline__ size_t get_current_tensor_id(
    const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset,
    const size_t first_logical_dim, const size_t last_logical_dim,
    const int64_t *const __restrict__ offsets_ptr) {
  if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) {
    const size_t current_row = current_offset / last_logical_dim;
    const size_t rows_per_tensor = first_logical_dim / num_tensors;
    return current_row / rows_per_tensor;
  } else {
    // upper_bound(offsets, current_offset) - 1 in range i in [0..num_tensors)
    size_t low = 0;
    size_t hi = num_tensors;  // half-open [low, hi)
    while (low < hi) {
      const size_t mid = low + (hi - low) / 2;
      const size_t mid_offset = static_cast<size_t>(offsets_ptr[mid]);
      if (mid_offset <= current_offset) {
        low = mid + 1;
      } else {
        hi = mid;
      }
    }
    // low = first index where offsets[low] > current_offset (or low == num_tensors)
    // id = low - 1, but need to evaluate if current_offset < offsets[0]
    return (low == 0) ? 0 : (low - 1);
  }
}
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 8> StochasticNumericConverterBase(
    cutlass::Array<float, 8> const &input, cutlass::Array<uint32_t, 2> const &rbits) {
  using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>;
  result_type output;
  auto output_ptr = reinterpret_cast<uint16_t *>(&output);
  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
  if constexpr (has_rs) {
    asm volatile(
        "{\n"
        "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n"
        "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n"
        "}"
        : "=h"(output_ptr[0]), "=h"(output_ptr[1])
        : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]),
          "f"(input[5]), "f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1]));
  } else {
    NVTE_DEVICE_ERROR(
        "FP4 cvt PTX instructions are architecture-specific. "
        "Try recompiling with sm_XXXa instead of sm_XXX.");
  }
  return output;
}
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 16> StochasticNumericConverter(
    cutlass::Array<float, 16> const &input, cutlass::Array<uint32_t, 4> const &rbits) {
  using result_type = cutlass::Array<cutlass::float_e2m1_t, 16>;
  result_type output;
  cutlass::Array<cutlass::float_e2m1_t, 8> *result_ptr =
      reinterpret_cast<cutlass::Array<cutlass::float_e2m1_t, 8> *>(&output);
  cutlass::Array<float, 8> const *source_ptr =
      reinterpret_cast<cutlass::Array<float, 8> const *>(&input);
  cutlass::Array<uint32_t, 2> const *rbits_ptr =
      reinterpret_cast<cutlass::Array<uint32_t, 2> const *>(&rbits);
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < 2; i++) {
    result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]);
  }
  return output;
}
template <class ElementA, class ElementB, class ASmemLayout, class BSmemLayout, class ClusterShape,
          int AccumulatorPipelineStageCount_, int EpilogueUnrollFactor_,
          int SchedulerPipelineStageCount_>
struct SharedStorage {
  static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
  static int constexpr EpilogueUnrollFactor = EpilogueUnrollFactor_;
  using AtomThrShapeMNK = cute::Shape<_1, _1, _1>;
  using AccumulatorPipeline =
      cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / EpilogueUnrollFactor,
                                 AtomThrShapeMNK>;
  using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage;
  static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{});
  using MainloopPipeline =
      cutlass::detail::CustomizedPipelineTmaUmmaAsync<MainloopPipelineStageCount,
                                                      Shape<_1, _1, _1>, AtomThrShapeMNK>;
  using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage;
  using SchedPipeline = cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount_, ClusterShape>;
  using SchedPipelineStorage = typename SchedPipeline::SharedStorage;
  using SchedThrottlePipeline = cutlass::PipelineAsync<SchedulerPipelineStageCount_>;
  using SchedThrottlePipelineStorage = typename SchedThrottlePipeline::SharedStorage;

  struct TensorStorage : cute::aligned_struct<128, _1> {
    cute::array_aligned<ElementA, cute::cosize_v<ASmemLayout>> smem_A;
    cute::array_aligned<ElementB, cute::cosize_v<BSmemLayout>> smem_B;
  } tensors;

  alignas(16) AccumulatorPipelineStorage accumulator;
  alignas(16) MainloopPipelineStorage mainloop;
  alignas(16) cute::uint64_t tma_barrier[1];
  alignas(16) SchedPipelineStorage sched;
  alignas(16) SchedThrottlePipelineStorage sched_throttle;
  alignas(16) int32_t atomic_tile_id[SchedulerPipelineStageCount_];
  alignas(16) float global_a_amax[kMaxTensorsPerKernel];
  alignas(16) float global_d_amax[kMaxTensorsPerKernel];
  uint32_t atomic_tile_counter[SchedulerPipelineStageCount_];
  uint32_t tmem_base_ptr;
};
// Main RHT GEMM kernel entry -- highly templated for flexible architecture/config support
template
<
class
MShape
,
class
NShape
,
class
KShape
,
class
ClusterShape
,
class
ClusterTileShape
,
class
TA
,
class
AStride
,
class
ASmemLayout
,
class
TmaLoadA
,
class
TB
,
class
BStride
,
class
BSmemLayout
,
class
TmaLoadB
,
class
TD
,
class
DStride
,
class
DSmemLayout
,
class
TSFD
,
class
TSFDLayout
,
class
TQA
,
class
QAStride
,
class
TSFA
,
class
TSFALayout
,
class
TiledMMA
,
int
AccumulatorPipelineStageCount_
,
int
SchedulerPipelineStageCount_
,
bool
kEnableStochasticRounding_
=
false
,
bool
kEnableRHTColQuant_
=
true
,
bool
kEnableRowQuant_
=
true
,
bool
kEnableSwizzleSFOutput_
=
false
,
bool
kUseFastMath_
=
false
>
__launch_bounds__
(
512
,
1
)
__global__
static
void
group_row_col_rht_gemm_device_graph_safe
(
MShape
M
,
NShape
packed_N
,
KShape
K
,
ClusterShape
cluster_shape
,
ClusterTileShape
cluster_tile
,
TA
const
*
A
,
AStride
dA
,
ASmemLayout
sAlayout
,
CUTE_GRID_CONSTANT
TmaLoadA
const
tma_load_a
,
TB
const
*
B
,
BStride
dB
,
BSmemLayout
sBlayout
,
CUTE_GRID_CONSTANT
TmaLoadB
const
tma_load_b
,
TQA
*
QA
,
QAStride
dQA
,
TSFA
*
SFA
,
TSFALayout
sfa_layout
,
TQA
*
QA_COLWISE
,
TSFA
*
SFA_COLWISE
,
float
*
amax_rowwise
,
float
*
amax_colwise
,
const
int64_t
*
offsets
,
const
int64_t
*
first_dims
,
size_t
num_tensors
,
ShapeRepresentation
shape_rep
,
uint32_t
*
tile_scheduler_workspace
,
TiledMMA
mma
,
const
size_t
*
rng_state
)
{
using
namespace
cute
;
// Abort immediately if compilation is not supported
constexpr
bool
is_blackwell_arch
=
ARCH_BLACKWELL_FAMILY
;
if
constexpr
(
!
is_blackwell_arch
)
{
NVTE_DEVICE_ERROR
(
"group_row_col_rht_gemm_device_graph_safe is only supported on Blackwell "
"with architecture-specific compilation. "
"Try recompiling with sm_100a or similar."
);
return
;
}
static_assert
(
kEnableRHTColQuant_
||
kEnableRowQuant_
,
"group_row_col_rht_gemm_device_graph_safe must generate row-wise "
"and/or column-wise output."
);
#if !defined(CUTLASS_ARCH_CLC_ENABLED)
CUTLASS_NOT_IMPLEMENTED
();
return
;
#endif
using
X
=
Underscore
;
// Accumulator data type for main computation
using
ElementAccumulator
=
float
;
static
int
constexpr
K_PIPE_MAX
=
size
<
3
>
(
ASmemLayout
{});
using
AtomThrShapeMNK
=
Shape
<
decltype
(
shape
<
0
>
(
typename
TiledMMA
::
ThrLayoutVMNK
{})),
_1
,
_1
>
;
static
uint32_t
constexpr
kTmaTransactionBytes
=
cutlass
::
bits_to_bytes
(
size
(
AtomThrShapeMNK
{})
*
cosize
(
take
<
0
,
3
>
(
ASmemLayout
{}))
*
cute
::
sizeof_bits_v
<
TA
>
);
static
constexpr
bool
kEnableStochasticRounding
=
kEnableStochasticRounding_
;
static
constexpr
bool
kEnableRHTColQuant
=
kEnableRHTColQuant_
;
static
constexpr
bool
kEnableRowQuant
=
kEnableRowQuant_
;
static
constexpr
bool
kEnableSwizzleSFOutput
=
kEnableSwizzleSFOutput_
;
static
constexpr
bool
kUseFastMath
=
kUseFastMath_
;
// Constant for RHT tensor processing (tile size etc)
static
int
constexpr
RhtTensorSize
=
16
;
// Get the total number of tokens to process
// Note that here M is the hidden size, which is the last logical dimension of the input tensor x
// The kernel is designed in column major, so M is the hidden size
size_t
sum_token_dims
=
offsets
[
num_tensors
]
/
M
;
// Transaction bytes for TMA transfer on RHT tensor blocks
static
int
constexpr
kTmaRhtTensorTransactionBytes
=
cutlass
::
bits_to_bytes
(
RhtTensorSize
*
RhtTensorSize
*
cute
::
sizeof_bits_v
<
TB
>
);
static
int
constexpr
AccumulatorPipelineStageCount
=
AccumulatorPipelineStageCount_
;
static
int
constexpr
SchedulerPipelineStageCount
=
SchedulerPipelineStageCount_
;
// Mainloop pipeline stage calculation, vectorization parameters for scaling factors
static
int
constexpr
MainloopPipelineStageCount
=
size
<
3
>
(
ASmemLayout
{});
static
int
constexpr
SFVecSize
=
16
;
// Swizzle output layout for scaling factor arrays
using
SwizzledSFALayoutAtom
=
cutlass
::
detail
::
Sm1xxBlockScaledOutputConfig
<
SFVecSize
,
UMMA
::
Major
::
MN
>::
SfAtom
;
using
SwizzledSFDLayoutAtom
=
cutlass
::
detail
::
Sm1xxBlockScaledOutputConfig
<
SFVecSize
,
UMMA
::
Major
::
K
>::
SfAtom
;
// Mainloop pipeline types for TMA async execution and epilogue cluster scheduling
using
MainloopPipeline
=
cutlass
::
detail
::
CustomizedPipelineTmaUmmaAsync
<
MainloopPipelineStageCount
,
ClusterShape
,
AtomThrShapeMNK
>
;
using
MainloopPipelineState
=
typename
MainloopPipeline
::
PipelineState
;
using
SchedPipeline
=
cutlass
::
PipelineCLCFetchAsync
<
SchedulerPipelineStageCount
,
ClusterShape
>
;
using
SchedPipelineState
=
typename
SchedPipeline
::
PipelineState
;
using
SchedThrottlePipeline
=
cutlass
::
PipelineAsync
<
SchedulerPipelineStageCount
>
;
using
SchedThrottlePipelineState
=
typename
SchedThrottlePipeline
::
PipelineState
;
static_assert
(
ClusterShape
{}
==
Shape
<
_1
,
_1
,
_1
>
{},
"ClusterShape must be Shape<_1,_1,_1>"
);
using
TmemAllocator
=
cute
::
TMEM
::
Allocator1Sm
;
static
int
constexpr
VectorSize
=
RhtTensorSize
;
// Compile-time safety: static shapes required for shared memory layouts
CUTE_STATIC_ASSERT
(
is_static
<
ASmemLayout
>::
value
);
CUTE_STATIC_ASSERT
(
is_static
<
BSmemLayout
>::
value
);
// CUTE_STATIC_ASSERT(is_static<DSmemLayout>::value);
auto
cluster_size
=
size
<
0
>
(
cluster_shape
);
auto
mainloop_tiler
=
Shape
<
_128
,
_16
,
_128
>
{};
auto
epilogue_tiler
=
Shape
<
_128
,
_128
,
_128
>
{};
static
int
constexpr
EpilogueUnrollFactor
=
size
<
2
>
(
epilogue_tiler
)
/
size
<
2
>
(
cluster_tile
);
// Get the appropriate blocks for this Cluster
dim3
cluster_coord_in_grid
=
cluster_id_in_grid
();
// Total number of k-tiles
int
const
K_TILE_MAX
=
min
(
packed_N
,
K
)
/
size
<
2
>
(
epilogue_tiler
);
struct
TileScheduler
{
uint32_t
tiles_in_m
=
0
;
uint32_t
tiles_in_n
=
0
;
uint32_t
linear_idx
=
0
;
uint32_t
next_linear_idx
=
0
;
uint32_t
start_idx
=
0
;
uint32_t
tile_m_idx
=
0
;
uint32_t
tile_n_idx
=
0
;
int
k_tile_max
=
0
;
uint32_t
*
atomic_tile_index_
;
uint32_t
*
smem_tile_counter
;
uint32_t
atomic_offset
;
cutlass
::
FastDivmodU64
divmod_tiles_in_m
;
CUTLASS_DEVICE
TileScheduler
(
uint32_t
tiles_m
,
uint32_t
tiles_n
,
int
kmax
,
uint32_t
*
atomic_tile_index
,
uint32_t
*
smem_tile_counter
)
:
tiles_in_m
(
tiles_m
),
tiles_in_n
(
tiles_n
),
linear_idx
(
blockIdx
.
x
),
next_linear_idx
(
blockIdx
.
x
),
start_idx
(
blockIdx
.
x
),
k_tile_max
(
kmax
),
atomic_tile_index_
(
atomic_tile_index
),
smem_tile_counter
(
smem_tile_counter
),
atomic_offset
(
gridDim
.
x
),
divmod_tiles_in_m
(
uint64_t
(
tiles_m
))
{
update_tile_idx
();
}
CUTLASS_DEVICE
void
update_tile_idx
()
{
uint64_t
q
,
r
;
divmod_tiles_in_m
(
q
,
r
,
uint64_t
(
linear_idx
));
tile_m_idx
=
static_cast
<
uint32_t
>
(
r
);
tile_n_idx
=
static_cast
<
uint32_t
>
(
q
)
*
uint32_t
(
k_tile_max
);
}
CUTLASS_DEVICE
uint32_t
tile_m
()
const
{
return
tile_m_idx
;
}
CUTLASS_DEVICE
uint32_t
tile_n_base
()
const
{
return
tile_n_idx
;
}
CUTLASS_DEVICE
uint32_t
tiles_m
()
const
{
return
tiles_in_m
;
}
CUTLASS_DEVICE
uint32_t
tiles_n
()
const
{
return
tiles_in_n
;
}
CUTLASS_DEVICE
bool
is_valid
()
const
{
return
cute
::
elem_less
(
cute
::
make_coord
(
tile_m
(),
tile_n_base
()),
cute
::
make_coord
(
tiles_in_m
,
tiles_in_n
));
}
CUTLASS_DEVICE
bool
is_first_wave
()
const
{
return
linear_idx
==
start_idx
;
}
CUTLASS_DEVICE
uint32_t
get_linear_tile_idx
()
const
{
return
linear_idx
;
}
// Fetch a new tile_id using atomics.
CUTLASS_DEVICE
uint32_t
fetch_tile_id_counter
(
int
pred
)
{
uint32_t
tile_id_counter
=
0
;
asm
volatile
(
"{
\n\t
"
".reg .pred p;
\n\t
"
"setp.eq.u32 p, %2, 1;
\n\t
"
"@p atom.global.add.u32 %0, [%1], 1;
\n\t
"
"}"
:
"=r"
(
tile_id_counter
)
:
"l"
(
atomic_tile_index_
),
"r"
(
pred
));
return
tile_id_counter
;
}
CUTLASS_DEVICE
auto
fetch_next_work
(
SchedPipeline
&
sched_pipeline
,
SchedPipelineState
sched_pipeline_consumer_state
)
{
sched_pipeline
.
consumer_wait
(
sched_pipeline_consumer_state
);
next_linear_idx
=
smem_tile_counter
[
sched_pipeline_consumer_state
.
index
()];
cutlass
::
arch
::
fence_view_async_shared
();
sched_pipeline
.
consumer_release
(
sched_pipeline_consumer_state
);
return
;
}
CUTLASS_DEVICE
auto
advance_to_next_work
(
SchedPipeline
&
sched_pipeline
,
SchedPipelineState
sched_pipeline_producer_state
)
{
uint32_t
mbarrier_addr
=
sched_pipeline
.
producer_get_barrier
(
sched_pipeline_producer_state
);
// Wait for clcID buffer to become empty with a flipped phase
sched_pipeline
.
producer_acquire
(
sched_pipeline_producer_state
);
auto
is_leading_thread
=
cute
::
elect_one_sync
();
uint32_t
tile_id_counter
=
fetch_tile_id_counter
(
is_leading_thread
)
+
atomic_offset
;
uint32_t
smem_addr
=
cute
::
cast_smem_ptr_to_uint
(
&
smem_tile_counter
[
sched_pipeline_producer_state
.
index
()]);
if
(
is_leading_thread
)
{
cute
::
store_shared_remote
(
tile_id_counter
,
smem_addr
,
mbarrier_addr
,
0
);
}
++
sched_pipeline_producer_state
;
return
sched_pipeline_producer_state
;
}
CUTLASS_DEVICE
auto
update_work_tile_info
()
{
linear_idx
=
next_linear_idx
;
update_tile_idx
();
return
;
}
};
// Allocate and alias shared memory to the kernel's shared storage type
extern
__shared__
char
shared_memory
[];
using
SharedStorage
=
SharedStorage
<
TA
,
TB
,
ASmemLayout
,
BSmemLayout
,
ClusterShape
,
AccumulatorPipelineStageCount
,
EpilogueUnrollFactor
,
SchedulerPipelineStageCount
>
;
SharedStorage
&
shared_storage
=
*
reinterpret_cast
<
SharedStorage
*>
(
shared_memory
);
// Compute the number of tiles in M and N after tiling and assign scheduler
uint32_t
tiles_in_m
=
uint32_t
(
size
(
ceil_div
(
M
,
size
<
0
>
(
cluster_tile
))));
uint32_t
tiles_in_n
=
uint32_t
(
size
(
ceil_div
(
sum_token_dims
,
size
<
2
>
(
epilogue_tiler
))));
TileScheduler
scheduler
(
tiles_in_m
,
tiles_in_n
,
K_TILE_MAX
,
tile_scheduler_workspace
,
shared_storage
.
atomic_tile_counter
);
int
block_rank_in_cluster
=
cute
::
block_rank_in_cluster
();
// Shapes for accumulated tiles in mainloop and epilogue
auto
acc_shape_mma
=
make_shape
(
take
<
0
,
2
>
(
mainloop_tiler
),
_1
{},
_1
{});
auto
acc_shape_epilogue
=
make_shape
(
take
<
0
,
2
>
(
epilogue_tiler
),
_1
{},
_1
{});
// Shape of the accumulator fragment for the main loop pipeline, with pipeline stages appended
auto
acc_mainloop_pipelined_shape
=
append
(
acc_shape_mma
,
Int
<
AccumulatorPipelineStageCount
>
{});
auto
bulk_tmem_mma
=
TiledMMA
::
make_fragment_C
(
acc_mainloop_pipelined_shape
);
// Number of threads assigned for various epilogue roles depending on quantization settings
static
int
constexpr
NumEpilogueColQuantThreadCount
=
kEnableRHTColQuant
?
128
:
0
;
static
int
constexpr
NumEpilogueRowQuantThreadCount
=
kEnableRowQuant
?
256
:
0
;
static
int
constexpr
NumMmaThreadCount
=
kEnableRHTColQuant
?
32
:
0
;
static
int
constexpr
NumMmaIssueThreadCount
=
kEnableRHTColQuant
?
1
:
0
;
static
int
constexpr
NumSchedThreads
=
32
;
static
int
constexpr
NumMainloopLoadThreads
=
32
;
static
int
constexpr
NumEpilogueThreads
=
NumEpilogueColQuantThreadCount
+
NumEpilogueRowQuantThreadCount
;
TmemAllocator
tmem_allocator
{};
cutlass
::
arch
::
NamedBarrier
tmem_allocation_result_barrier
(
NumMmaThreadCount
+
NumEpilogueColQuantThreadCount
,
cutlass
::
arch
::
ReservedNamedBarriers
::
TmemAllocBarrier
);
int
warp_idx
=
cutlass
::
canonical_warp_idx_sync
();
// warp assignment
bool
is_mma_warp
=
(
warp_idx
==
0
);
bool
is_dma_warp
=
(
warp_idx
==
1
);
bool
is_sched_warp
=
(
warp_idx
==
2
);
bool
is_epilogue_col_quant_warp
=
(
warp_idx
>=
4
&&
warp_idx
<=
7
);
bool
is_epilogue_row_quant_warp
=
(
warp_idx
>=
8
&&
warp_idx
<=
15
);
typename
MainloopPipeline
::
Params
mainloop_pipeline_params
;
if
(
is_dma_warp
)
{
mainloop_pipeline_params
.
role
=
MainloopPipeline
::
ThreadCategory
::
Producer
;
}
if
(
is_mma_warp
)
{
mainloop_pipeline_params
.
role
=
MainloopPipeline
::
ThreadCategory
::
Consumer
;
}
mainloop_pipeline_params
.
is_leader
=
cute
::
elect_one_sync
()
&&
is_dma_warp
;
mainloop_pipeline_params
.
transaction_bytes
=
kTmaTransactionBytes
;
mainloop_pipeline_params
.
initializing_warp
=
0
;
mainloop_pipeline_params
.
num_consumers
=
NumEpilogueRowQuantThreadCount
+
NumMmaIssueThreadCount
;
MainloopPipeline
mainloop_pipeline
(
shared_storage
.
mainloop
,
mainloop_pipeline_params
,
cluster_shape
,
cute
::
true_type
{},
// Perform barrier init
cute
::
true_type
{});
// Delay mask calculation
MainloopPipelineState
mainloop_pipe_consumer_state
;
MainloopPipelineState
mainloop_pipe_producer_state
=
cutlass
::
make_producer_start_state
<
MainloopPipeline
>
();
using
AccumulatorPipeline
=
cutlass
::
PipelineUmmaAsync
<
AccumulatorPipelineStageCount
/
EpilogueUnrollFactor
,
AtomThrShapeMNK
>
;
using
AccumulatorPipelineState
=
typename
AccumulatorPipeline
::
PipelineState
;
using
AccumulatorPipelineInitBarriers
=
cute
::
bool_constant
<
kEnableRHTColQuant
>
;
AccumulatorPipelineState
accumulator_pipe_consumer_state
;
AccumulatorPipelineState
accumulator_pipe_producer_state
=
cutlass
::
make_producer_start_state
<
AccumulatorPipeline
>
();
typename
AccumulatorPipeline
::
Params
accumulator_pipeline_params
;
if
(
is_mma_warp
)
{
accumulator_pipeline_params
.
role
=
AccumulatorPipeline
::
ThreadCategory
::
Producer
;
}
if
(
is_epilogue_col_quant_warp
)
{
accumulator_pipeline_params
.
role
=
AccumulatorPipeline
::
ThreadCategory
::
Consumer
;
}
// Only one producer thread arrives on this barrier.
accumulator_pipeline_params
.
producer_arv_count
=
1
;
accumulator_pipeline_params
.
consumer_arv_count
=
size
(
AtomThrShapeMNK
{})
*
NumEpilogueColQuantThreadCount
;
accumulator_pipeline_params
.
initializing_warp
=
1
;
AccumulatorPipeline
accumulator_pipeline
(
shared_storage
.
accumulator
,
accumulator_pipeline_params
,
cluster_shape
,
AccumulatorPipelineInitBarriers
{},
cute
::
true_type
{});
// Delay mask calculation
typename
SchedPipeline
::
Params
sched_pipeline_params
;
if
(
is_sched_warp
)
{
sched_pipeline_params
.
role
=
SchedPipeline
::
ThreadCategory
::
ProducerConsumer
;
}
else
{
sched_pipeline_params
.
role
=
SchedPipeline
::
ThreadCategory
::
Consumer
;
}
sched_pipeline_params
.
producer_blockid
=
0
;
sched_pipeline_params
.
producer_arv_count
=
1
;
sched_pipeline_params
.
consumer_arv_count
=
NumSchedThreads
+
cluster_size
*
(
NumMainloopLoadThreads
+
NumEpilogueThreads
+
NumMmaThreadCount
);
sched_pipeline_params
.
transaction_bytes
=
sizeof
(
uint32_t
);
sched_pipeline_params
.
initializing_warp
=
3
;
SchedPipeline
sched_pipeline
(
shared_storage
.
sched
,
sched_pipeline_params
,
cluster_shape
);
SchedPipelineState
sched_pipeline_consumer_state
;
SchedPipelineState
sched_pipeline_producer_state
=
cutlass
::
make_producer_start_state
<
SchedPipeline
>
();
typename
SchedThrottlePipeline
::
Params
sched_throttle_pipeline_params
;
if
(
is_dma_warp
)
{
sched_throttle_pipeline_params
.
role
=
SchedThrottlePipeline
::
ThreadCategory
::
Producer
;
}
if
(
is_sched_warp
)
{
sched_throttle_pipeline_params
.
role
=
SchedThrottlePipeline
::
ThreadCategory
::
Consumer
;
}
sched_throttle_pipeline_params
.
producer_arv_count
=
NumMainloopLoadThreads
;
sched_throttle_pipeline_params
.
consumer_arv_count
=
NumSchedThreads
;
sched_throttle_pipeline_params
.
dst_blockid
=
0
;
sched_throttle_pipeline_params
.
initializing_warp
=
4
;
SchedThrottlePipeline
sched_throttle_pipeline
(
shared_storage
.
sched_throttle
,
sched_throttle_pipeline_params
);
SchedThrottlePipelineState
sched_pipeline_throttle_consumer_state
;
SchedThrottlePipelineState
sched_pipeline_throttle_producer_state
=
cutlass
::
make_producer_start_state
<
SchedThrottlePipeline
>
();
if
(
warp_idx
==
2
&&
elect_one_sync
())
{
cute
::
initialize_barrier
(
shared_storage
.
tma_barrier
[
0
],
/* num_threads */
1
);
}
__syncthreads
();
// Warp group roles: DMA (global->shared copy), MMA (tensor core gemm), scheduler, column quantizer, row quantizer
if
(
is_dma_warp
)
{
// Warp responsible for loading input from global to shared memory using TMA (Tensor Memory Access).
cutlass
::
arch
::
warpgroup_reg_dealloc
<
32
>
();
// Get TMA tensors for input matrix A and B (Hadamard/transform matrix) from global memory.
Tensor
mA
=
tma_load_a
.
get_tma_tensor
(
make_shape
(
M
,
packed_N
));
Tensor
mB
=
tma_load_b
.
get_tma_tensor
(
make_shape
(
RhtTensorSize
,
RhtTensorSize
));
// Partition tensors for tiling according to the mainloop and cluster tilers.
Tensor
gA_mk
=
local_tile
(
mA
,
mainloop_tiler
,
make_coord
(
_
,
_
,
_
),
Step
<
_1
,
X
,
_1
>
{});
Tensor
gB_nk
=
local_tile
(
mB
,
cluster_tile
,
make_coord
(
_
,
_
,
_
),
Step
<
X
,
_1
,
_1
>
{});
// (BLK_N,BLK_K,k)
// Shared memory tensors for pipeline
Tensor
tCsA
=
make_tensor
(
make_smem_ptr
(
shared_storage
.
tensors
.
smem_A
.
data
()),
sAlayout
);
// (MMA,MMA_M,MMA_N,PIPE)
Tensor
tCsB
=
make_tensor
(
make_smem_ptr
(
shared_storage
.
tensors
.
smem_B
.
data
()),
sBlayout
);
// (MMA,MMA_N,MMA_K,PIPE)
// Determine warp/tile positioning
int
block_rank_in_cluster
=
cute
::
block_rank_in_cluster
();
ThrMMA
thr_mma
=
mma
.
get_slice
(
block_rank_in_cluster
);
// blk idx
// Partition global to local fragments for A and B
Tensor
tCgA
=
thr_mma
.
partition_A
(
gA_mk
);
// (MMA,MMA_M,MMA_K,k)
Tensor
tCgB
=
thr_mma
.
partition_B
(
gB_nk
);
// (MMA,MMA_N,MMA_K,k)
Layout
cta_layout_mnk
=
make_layout
(
cluster_shape
);
Layout
cta_layout_vmnk
=
tiled_divide
(
cta_layout_mnk
,
make_tile
(
typename
TiledMMA
::
AtomThrID
{}));
auto
cta_coord_vmnk
=
cta_layout_vmnk
.
get_flat_coord
(
block_rank_in_cluster
);
auto
[
tAgA
,
tAsA
]
=
tma_partition
(
tma_load_a
,
get
<
2
>
(
cta_coord_vmnk
),
make_layout
(
size
<
2
>
(
cta_layout_vmnk
)),
group_modes
<
0
,
3
>
(
tCsA
),
group_modes
<
0
,
3
>
(
tCgA
));
auto
[
tBgB
,
tBsB
]
=
tma_partition
(
tma_load_b
,
get
<
1
>
(
cta_coord_vmnk
),
make_layout
(
size
<
1
>
(
cta_layout_vmnk
)),
group_modes
<
0
,
3
>
(
tCsB
),
group_modes
<
0
,
3
>
(
tCgB
));
uint16_t
tma_mcast_mask_a
=
create_tma_multicast_mask
<
2
>
(
cta_layout_vmnk
,
cta_coord_vmnk
);
uint16_t
tma_mcast_mask_b
=
create_tma_multicast_mask
<
1
>
(
cta_layout_vmnk
,
cta_coord_vmnk
);
if
constexpr
(
kEnableRHTColQuant
)
{
if
(
elect_one_sync
())
{
cute
::
set_barrier_transaction_bytes
(
shared_storage
.
tma_barrier
[
0
],
kTmaRhtTensorTransactionBytes
);
copy
(
tma_load_b
.
with
(
shared_storage
.
tma_barrier
[
0
],
tma_mcast_mask_b
),
tBgB
(
_
,
0
,
0
),
tBsB
(
_
,
0
));
}
}
do
{
// is_first_wave indicates whether this scheduler wave is the first among a group.
bool
is_first_wave
=
scheduler
.
is_first_wave
();
uint32_t
skip_wait
=
is_first_wave
;
auto
tAgA_mk
=
tAgA
(
_
,
scheduler
.
tile_m
(),
_
);
int
k_tile
=
0
;
sched_throttle_pipeline
.
producer_acquire
(
sched_pipeline_throttle_producer_state
);
sched_throttle_pipeline
.
producer_commit
(
sched_pipeline_throttle_producer_state
);
++
sched_pipeline_throttle_producer_state
;
CUTLASS_PRAGMA_NO_UNROLL
while
(
k_tile
<
K_TILE_MAX
&&
k_tile
+
scheduler
.
tile_n_base
()
<
scheduler
.
tiles_n
())
{
int
k_tile_idx_n
=
scheduler
.
tile_n_base
()
+
k_tile
;
++
k_tile
;
skip_wait
=
(
is_first_wave
&&
k_tile
<
MainloopPipelineStageCount
);
mainloop_pipeline
.
producer_acquire
(
mainloop_pipe_producer_state
);
using
BarrierType
=
typename
MainloopPipeline
::
ProducerBarrierType
;
BarrierType
*
tma_barrier
=
mainloop_pipeline
.
producer_get_barrier
(
mainloop_pipe_producer_state
);
int
write_stage
=
mainloop_pipe_producer_state
.
index
();
++
mainloop_pipe_producer_state
;
if
(
cute
::
elect_one_sync
())
{
copy
(
tma_load_a
.
with
(
*
tma_barrier
,
tma_mcast_mask_a
),
tAgA_mk
(
_
,
k_tile_idx_n
),
tAsA
(
_
,
write_stage
));
}
}
scheduler
.
fetch_next_work
(
sched_pipeline
,
sched_pipeline_consumer_state
);
++
sched_pipeline_consumer_state
;
scheduler
.
update_work_tile_info
();
// scheduler.advance();
}
while
(
scheduler
.
is_valid
());
mainloop_pipeline
.
producer_tail
(
mainloop_pipe_producer_state
);
}
else
if
(
is_mma_warp
)
{
// This warp executes the main tensor core matrix-multiply-accumulate for the Hadamard transform.
cutlass
::
arch
::
warpgroup_reg_dealloc
<
32
>
();
if
constexpr
(
kEnableRHTColQuant
)
{
// Setup shared memory fragments for A and B tiles.
Tensor
tCsA
=
make_tensor
(
make_smem_ptr
(
shared_storage
.
tensors
.
smem_A
.
data
()),
sAlayout
);
// (MMA,MMA_M,MMA_N,PIPE)
Tensor
tCsB
=
make_tensor
(
make_smem_ptr
(
shared_storage
.
tensors
.
smem_B
.
data
()),
sBlayout
);
// (MMA,MMA_N,MMA_K,PIPE)
int
block_rank_in_cluster
=
cute
::
block_rank_in_cluster
();
ThrMMA
thr_mma
=
mma
.
get_slice
(
block_rank_in_cluster
);
// blk idx
// Allocate "fragments" -- these are actually umma smem descriptors
Tensor
tCrA
=
thr_mma
.
make_fragment_A
(
tCsA
);
// (MMA,MMA_M,MMA_K,PIPE)
Tensor
tCrB
=
thr_mma
.
make_fragment_B
(
tCsB
);
// (MMA,MMA_M,MMA_K,PIPE)
mma
.
accumulate_
=
UMMA
::
ScaleOut
::
Zero
;
tmem_allocator
.
allocate
(
TmemAllocator
::
Sm100TmemCapacityColumns
,
&
shared_storage
.
tmem_base_ptr
);
__syncwarp
();
tmem_allocation_result_barrier
.
arrive
();
uint32_t
tmem_base_ptr
=
shared_storage
.
tmem_base_ptr
;
bulk_tmem_mma
.
data
()
=
tmem_base_ptr
;
// Wait until the B (Hadamard) tensor copy is complete
cute
::
wait_barrier
(
shared_storage
.
tma_barrier
[
0
],
0
/*tma_phase_bit*/
);
do
{
uint32_t
skip_wait
=
K_TILE_MAX
<=
0
;
auto
barrier_token
=
mainloop_pipeline
.
consumer_try_wait
(
mainloop_pipe_consumer_state
,
skip_wait
);
scheduler
.
fetch_next_work
(
sched_pipeline
,
sched_pipeline_consumer_state
);
++
sched_pipeline_consumer_state
;
CUTLASS_PRAGMA_NO_UNROLL
for
(
int
k_tile
=
0
;
k_tile
<
K_TILE_MAX
&&
k_tile
+
scheduler
.
tile_n_base
()
<
scheduler
.
tiles_n
();)
{
mainloop_pipeline
.
consumer_wait
(
mainloop_pipe_consumer_state
,
barrier_token
);
int
read_stage
=
mainloop_pipe_consumer_state
.
index
();
auto
tCrA_mk
=
tCrA
(
_
,
_
,
_
,
read_stage
);
auto
tCrB_nk
=
tCrB
(
_
,
_
,
0
,
0
);
CUTLASS_PRAGMA_UNROLL
for
(
int
k_block
=
0
;
k_block
<
size
<
2
>
(
tCrA
)
/
EpilogueUnrollFactor
;
++
k_block
)
{
int
accumulator_k_block
=
accumulator_pipe_producer_state
.
index
()
*
EpilogueUnrollFactor
;
int
tCrA_k_block
=
k_block
*
EpilogueUnrollFactor
;
accumulator_pipeline
.
producer_acquire
(
accumulator_pipe_producer_state
);
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
EpilogueUnrollFactor
;
i
++
)
{
auto
accumulators
=
bulk_tmem_mma
(
_
,
_
,
_
,
accumulator_k_block
+
i
);
gemm
(
mma
,
tCrA_mk
(
_
,
_
,
tCrA_k_block
+
i
),
tCrB_nk
,
accumulators
);
}
accumulator_pipeline
.
producer_commit
(
accumulator_pipe_producer_state
);
++
accumulator_pipe_producer_state
;
}
auto
curr_mainloop_pipe_consumer_state
=
mainloop_pipe_consumer_state
;
++
mainloop_pipe_consumer_state
;
++
k_tile
;
skip_wait
=
k_tile
>=
K_TILE_MAX
;
mainloop_pipeline
.
umma_consumer_release
(
curr_mainloop_pipe_consumer_state
);
barrier_token
=
mainloop_pipeline
.
consumer_try_wait
(
mainloop_pipe_consumer_state
,
skip_wait
);
}
scheduler
.
update_work_tile_info
();
}
while
(
scheduler
.
is_valid
());
tmem_allocator
.
release_allocation_lock
();
accumulator_pipeline
.
producer_tail
(
accumulator_pipe_producer_state
);
tmem_allocator
.
free
(
tmem_base_ptr
,
TmemAllocator
::
Sm100TmemCapacityColumns
);
}
}
else
if
(
is_sched_warp
)
{
// Scheduler warp manages tile assignment and pipeline progress for warps
cutlass
::
arch
::
warpgroup_reg_dealloc
<
32
>
();
do
{
sched_throttle_pipeline
.
consumer_wait
(
sched_pipeline_throttle_consumer_state
);
sched_throttle_pipeline
.
consumer_release
(
sched_pipeline_throttle_consumer_state
);
++
sched_pipeline_throttle_consumer_state
;
sched_pipeline_producer_state
=
scheduler
.
advance_to_next_work
(
sched_pipeline
,
sched_pipeline_producer_state
);
scheduler
.
fetch_next_work
(
sched_pipeline
,
sched_pipeline_consumer_state
);
++
sched_pipeline_consumer_state
;
scheduler
.
update_work_tile_info
();
}
while
(
scheduler
.
is_valid
());
  } else if (is_epilogue_col_quant_warp) {
    // Warp responsible for quantizing output of Hadamard transform to FP4 for columnwise usage,
    // and writing result tensors/scales to global memory.
    cutlass::arch::warpgroup_reg_alloc<192>();
    if constexpr (kEnableRHTColQuant) {
      using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x;
      auto acc_epilogue_pipelined_shape =
          append(acc_shape_epilogue, Int<AccumulatorPipelineStageCount / EpilogueUnrollFactor>{});
      auto bulk_tmem_epilogue_layout =
          make_layout(acc_epilogue_pipelined_shape,
                      make_stride(stride<0>(bulk_tmem_mma), Int<0>{}, Int<0>{}, size<1>(epilogue_tiler)));
      auto bulk_tmem_epilogue = make_tensor(make_tmem_ptr<uint32_t>(), bulk_tmem_epilogue_layout);
      // Use 256-bit fragments for aligned bulk stores
      static int constexpr FragmentSize = 256 / sizeof_bits_v<TD>;
      // Wait for TMEM allocation for this pipeline to finish
      tmem_allocation_result_barrier.arrive_and_wait();
      uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
      bulk_tmem_epilogue.data() = tmem_base_ptr;
      int global_thread_idx = threadIdx.x;
      int local_thread_idx = global_thread_idx % cutlass::NumThreadsPerWarpGroup;
      // g2s load all global_d_amax
      CUTLASS_PRAGMA_NO_UNROLL
      for (int g = local_thread_idx; g < num_tensors; g += NumEpilogueColQuantThreadCount) {
        shared_storage.global_d_amax[g] = __ldg(reinterpret_cast<float *>(amax_colwise + g));
      }
      size_t rng_seed = 0;
      size_t rng_offset = 0;
      // Setup RNG for stochastic rounding
      if constexpr (kEnableStochasticRounding) {
        rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0;
        rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0;
      }
      // TODO(zhongbo): double check the logic here
      int group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                            (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
                                            packed_N, M, offsets);
      // Determine quantization scale factor layouts/output splits for this group
      TSFDLayout sfd_layout;
      int cur_N = static_cast<int>(first_dims[group_idx]);
      if constexpr (kEnableSwizzleSFOutput) {
        sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
      } else {
        sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
                                 make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
      }
      // Build output tensors for columns and their quant scales
      // TODO(zhongbo): double check the logic here
      Tensor mD = make_tensor(cute::subbyte_iterator<TD>(reinterpret_cast<TD *>(
                                  reinterpret_cast<char *>(QA_COLWISE) + offsets[group_idx] / 2)),
                              make_shape(M, cur_N), DStride{});  // (M,packed_N)
      Tensor gD_mn = local_tile(mD, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{});  // (BLK_M,BLK_N)
      // for every tensor [x, y] row major, x y both a multiple of 128
      // both of its rowwise and colwise scaling factors will have exactly x * y / 16 elements in FP8 E4M3
      Tensor mSFD = make_tensor(make_gmem_ptr<TSFD>(reinterpret_cast<TSFD *>(
                                    reinterpret_cast<char *>(SFA_COLWISE) + offsets[group_idx] / kNVFP4BlockSize)),
                                sfd_layout);
      Tensor gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{});  // (BLK_M,BLK_N)
      Tensor gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler));
      // Setup tile-level TMEM (t2r) and global memory (r2g) copy descriptors
      auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{}));
      auto tiled_r2g = make_tiled_copy_D(Copy_Atom<SM100_STORE_256bit_CACHE_NOALLOCATION, TD>{}, tiled_t2r);
      auto thr_t2r = tiled_t2r.get_slice(local_thread_idx);
      auto thr_r2g = tiled_r2g.get_slice(local_thread_idx);
      cutlass::arch::NamedBarrier::sync(NumEpilogueColQuantThreadCount,
                                        cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
      // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release}
      static constexpr float fp4_max = 6.0f;
      static constexpr float fp8_max = 448.0f;
      static constexpr float fp4_max_inv = 1.0f / fp4_max;
      float c_global_amax_val = shared_storage.global_d_amax[group_idx];
      float global_encode_scale =
          c_global_amax_val > 0.0f
              ? cutlass::minimum_with_nan_propagation<float>{}((fp8_max * fp4_max) / c_global_amax_val,
                                                               cutlass::platform::numeric_limits<float>::max())
              : 1.0f;
      float global_decode_scale = 1.0f / global_encode_scale;
      // Scaling factor for fast math path
      float global_encode_scale_multiplier = 1.0f;
      if constexpr (kUseFastMath) {
        global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
      }
      do {
        scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state);
        ++sched_pipeline_consumer_state;
        CUTLASS_PRAGMA_NO_UNROLL
        for (int k_tile = 0;
             k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n(); ++k_tile) {
          int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler);
          // TODO(zhongbo): double check the logic here
          int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors, global_tile_n_offset * M,
                                                    packed_N, M, offsets);
          if (cur_group_idx != group_idx) {
            group_idx = cur_group_idx;
            c_global_amax_val = shared_storage.global_d_amax[group_idx];
            // update amax
            global_encode_scale =
                c_global_amax_val > 0.0f
                    ? cutlass::minimum_with_nan_propagation<float>{}(
                          (fp8_max * fp4_max) / c_global_amax_val,
                          cutlass::platform::numeric_limits<float>::max())
                    : 1.0f;
            global_decode_scale = 1.0f / global_encode_scale;
            if constexpr (kUseFastMath) {
              global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
            }
            // TODO(zhongbo): double check the logic here
            cur_N = first_dims[group_idx];
            if constexpr (kEnableSwizzleSFOutput) {
              sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
            } else {
              sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
                                       make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
            }
            // update tensor
            mD = make_tensor(cute::subbyte_iterator<TD>(reinterpret_cast<TD *>(
                                 reinterpret_cast<char *>(QA_COLWISE) + offsets[group_idx] / 2)),
                             make_shape(M, cur_N), DStride{});
            gD_mn = local_tile(mD, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{});  // (BLK_M,BLK_N)
            mSFD = make_tensor(make_gmem_ptr<TSFD>(reinterpret_cast<TSFD *>(
                                   reinterpret_cast<char *>(SFA_COLWISE) + offsets[group_idx] / kNVFP4BlockSize)),
                               sfd_layout);
            gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{});  // (BLK_M,BLK_N)
            gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler));
          }
          int group_start_offset = offsets[group_idx] / M;
          int local_tile_n_idx = (global_tile_n_offset - group_start_offset) / size<1>(epilogue_tiler);
          Tensor tDgD_mn = gD_mn_view(_, _, _, scheduler.tile_m(), local_tile_n_idx);
          Tensor tDgSFD_mn = gSFD_mn(_, _, scheduler.tile_m(), local_tile_n_idx);
          accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state);
          auto Acc = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index());
          Tensor tDtAcc = thr_t2r.partition_S(Acc);    // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
          Tensor tDgD = thr_t2r.partition_D(tDgD_mn);  // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
          Tensor tTR_rAcc = make_tensor<ElementAccumulator>(shape(tDgD));  // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
          Tensor tDrD = make_tensor<TD>(shape(tDgD));
          Tensor tTR_rAcc_frag = recast<cutlass::Array<ElementAccumulator, FragmentSize>>(coalesce(tTR_rAcc));
          Tensor tDrD_frag = recast<cutlass::Array<TD, FragmentSize>>(coalesce(tDrD));
          Tensor src = thr_r2g.retile_S(tDrD);
          Tensor dst = thr_r2g.retile_D(tDgD);
          Tensor tDgSFD_view = make_tensor(tDgSFD_mn.data(),
                                           make_layout(make_shape(shape(tDgSFD_mn), Int<1>{}, Int<1>{}),
                                                       make_stride(stride(tDgSFD_mn), Int<0>{}, Int<0>{})));
          Tensor tDgSFD = filter(thr_t2r.partition_D(tDgSFD_view));
          Tensor tDrSFD = make_tensor<TSFD>(shape(tDgSFD));
          static int constexpr NumVecs = size(tDgD) / VectorSize;
          Tensor tD_rRowSFD_frg = recast<cutlass::Array<TSFD, NumVecs>>(tDrSFD);
          // Compute amax and quantization scales for this tile
          cutlass::maximum_absolute_value_reduction<cutlass::Array<ElementAccumulator, VectorSize>, true>
              amax_reduction;
          cutlass::Array<ElementAccumulator, NumVecs> vec_maxs;
          cutlass::Array<ElementAccumulator, NumVecs> pvscales;
          // Copy from TMEM to registers
          copy(tiled_t2r, tDtAcc, tTR_rAcc);
          cutlass::arch::fence_view_async_tmem_load();
          accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state);
          ++accumulator_pipe_consumer_state;
          if constexpr (!kUseFastMath) {
            // Downcast to BF16 for bit-wise compatibility with
            // unfused kernels
            auto convert_accum_to_bf16 =
                cutlass::NumericArrayConverter<cutlass::bfloat16_t, ElementAccumulator, FragmentSize>{};
            auto convert_bf16_to_accum =
                cutlass::NumericArrayConverter<ElementAccumulator, cutlass::bfloat16_t, FragmentSize>{};
            tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{})));
            tTR_rAcc_frag(_1{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_1{})));
          }
          auto compute_frgs =
              reinterpret_cast<cutlass::Array<ElementAccumulator, VectorSize> *>(tTR_rAcc_frag.data());
          auto output_frgs = reinterpret_cast<cutlass::Array<TD, VectorSize> *>(tDrD_frag.data());
          CUTLASS_PRAGMA_UNROLL
          for (int v = 0; v < NumVecs; v++) {
            vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]);
          }
          if constexpr (kUseFastMath) {
            // Fast math: multiply with precomputed reciprocal
            pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(
                vec_maxs, global_encode_scale_multiplier);
          } else {
            // Accurate math: perform division
            pvscales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(vec_maxs, fp4_max);
            pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(pvscales,
                                                                                          global_encode_scale);
          }
          auto pvscales_cvted = cutlass::NumericArrayConverter<TSFD, ElementAccumulator, NumVecs>{}(pvscales);
          tD_rRowSFD_frg(_0{}) = pvscales_cvted;
          auto qpvscale_ups =
              cutlass::NumericArrayConverter<ElementAccumulator, TSFD, NumVecs>{}(tD_rRowSFD_frg(_0{}));
          auto qpvscale_scaled = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(
              qpvscale_ups, global_decode_scale);
          cutlass::Array<ElementAccumulator, NumVecs> acc_scales;
          if constexpr (kUseFastMath) {
            // Fast math: compute approximate reciprocal
            acc_scales = cutlass::reciprocal_approximate_ftz<decltype(qpvscale_scaled)>{}(qpvscale_scaled);
          } else {
            // Accurate math: compute reciprocal with division
            acc_scales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(1.0, qpvscale_scaled);
          }
          // Prepare stochastic rounding random state if enabled
          uint4 random_uint4 = uint4{0, 0, 0, 0};
          transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
          // "Prefetch" a stochastic rounding state for the first tile
          if constexpr (kEnableStochasticRounding) {
            const size_t rng_sequence = global_thread_idx + k_tile * 512 +
                                        scheduler.get_linear_tile_idx() * K_TILE_MAX * 512;
            rng.init(rng_seed, rng_sequence, rng_offset);
          }
          CUTLASS_PRAGMA_UNROLL
          // Apply round/quantize to each fragment, with or without stochastic rounding
          for (int v = 0; v < NumVecs; v++) {
            auto acc_scale = cutlass::minimum_with_nan_propagation<ElementAccumulator>{}(
                acc_scales[v], cutlass::platform::numeric_limits<ElementAccumulator>::max());
            if constexpr (kEnableStochasticRounding) {
              random_uint4 = rng.generate4();
              output_frgs[v] = StochasticNumericConverter(
                  cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(compute_frgs[v], acc_scale),
                  *reinterpret_cast<cutlass::Array<uint32_t, 4> *>(&random_uint4));
            } else {
              output_frgs[v] = cutlass::NumericArrayConverter<TD, ElementAccumulator, VectorSize>{}(
                  cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(compute_frgs[v], acc_scale));
            }
          }
          // Write quantized FP4 tile and dequant scale to gmem
          copy(tiled_r2g, src, dst);
          copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFD, tDgSFD);
        }
        scheduler.update_work_tile_info();
      } while (scheduler.is_valid());
    }
  } else if (is_epilogue_row_quant_warp) {
    // Warp responsible for quantizing the input (before Hadamard transform) to FP4 for row-wise usage.
    cutlass::arch::warpgroup_reg_alloc<136>();
    if constexpr (kEnableRowQuant) {
      using S2RVectorType = uint128_t;
      int global_thread_idx = threadIdx.x;
      int local_thread_idx = global_thread_idx % 256;
      size_t rng_seed = 0;
      size_t rng_offset = 0;
      // g2s load all global_a_amax for all groups/tensors
      CUTLASS_PRAGMA_NO_UNROLL
      for (int g = local_thread_idx; g < num_tensors; g += NumEpilogueRowQuantThreadCount) {
        shared_storage.global_a_amax[g] = __ldg(reinterpret_cast<float *>(amax_rowwise + g));
      }
      // RNG for stochastic rounding
      if constexpr (kEnableStochasticRounding) {
        rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0;
        rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0;
      }
      // Input/output tensors/partitions for row quant warp
      Tensor mQA = make_tensor(cute::subbyte_iterator<TQA>(QA), make_layout(make_shape(M, packed_N), dQA));
      Tensor gQA_mn = local_tile(mQA, epilogue_tiler, make_coord(_, _, _), Step<_1, X, _1>{});
      Tensor mSFA = make_tensor(make_gmem_ptr(SFA), sfa_layout);
      Tensor gSFA_mn = local_tile(mSFA, epilogue_tiler, make_coord(_, _, _), Step<_1, X, _1>{});  // (BLK_M,BLK_N)
      // Swizzled shared memory A tile, with layout
      Tensor sA = as_position_independent_swizzle_tensor(group_modes<0, 2>(
          coalesce(make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout))));  // (BLOCK_M, BLOCK_M,PIPE)
      // Set up layouts for partitioning - tile-by-warp, with vector granularity
      using S2RWarpLayout = Layout<Shape<_8, _4>>;
      using WarpGroupLayout = Layout<Shape<_1, _8>>;
      using S2RThreadLayout = decltype(blocked_product(S2RWarpLayout{}, WarpGroupLayout{}));
      using S2RValLayout = Layout<Shape<Int<VectorSize>, _1>>;
      using S2RAtomA = Copy_Atom<AutoVectorizingCopy, TA>;
      using R2GAtomQA = Copy_Atom<AutoVectorizingCopy, TQA>;
      using R2GAtomSFA = Copy_Atom<AutoVectorizingCopy, TSFA>;
      auto tiled_s2r = make_tiled_copy(S2RAtomA{}, S2RThreadLayout{}, S2RValLayout{});
      auto tiled_r2g_QA = make_tiled_copy(R2GAtomQA{}, S2RThreadLayout{}, S2RValLayout{});
      auto tiled_r2g_SFA = make_tiled_copy(R2GAtomSFA{}, S2RThreadLayout{}, S2RValLayout{});
      auto thr_s2r = tiled_s2r.get_slice(local_thread_idx);
      auto thr_r2g_QA = tiled_r2g_QA.get_slice(local_thread_idx);
      auto thr_r2g_SFA = tiled_r2g_SFA.get_slice(local_thread_idx);
      Tensor tQAsA = thr_s2r.partition_S(sA);  // (Copy, Copy_M, Copy_N, PIPE)
      // Allocate temporary register tensors for copying quantization => output
      Tensor tQArA = make_tensor_like<TA>(make_layout(tQAsA(_, _, _, _0{}).shape()));  // (Copy, Copy_M, Copy_N)
      Tensor tQAgQA = thr_r2g_QA.partition_S(gQA_mn);
      Tensor tQArQA = make_tensor_like(tQAgQA(_, _, _, _0{}, _0{}));
      Tensor tQAgSFA = thr_r2g_SFA.partition_S(gSFA_mn);
      Tensor tQArSFA = make_tensor_like(tQAgSFA(_, _, _, _0{}, _0{}));
      // Will result in barrier_id=10 passed to the bar.sync instruction, since cutlass adds 8
      // to skip past the reserved named barriers.
      constexpr int row_quant_barrier_id = 2;
      cutlass::arch::NamedBarrier::sync(NumEpilogueRowQuantThreadCount, row_quant_barrier_id);
      int group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                            (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
                                            packed_N, M, offsets);
      float a_global_amax_val = shared_storage.global_a_amax[group_idx];
      // Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release}
      static constexpr float fp4_max = 6.0f;
      static constexpr float fp8_max = 448.0f;
      static constexpr float fp4_max_inv = 1.0f / fp4_max;
      float global_encode_scale =
          a_global_amax_val > 0.0f
              ? cutlass::minimum_with_nan_propagation<float>{}((fp8_max * fp4_max) / a_global_amax_val,
                                                               cutlass::platform::numeric_limits<float>::max())
              : 1.0f;
      float global_decode_scale = 1.0f / global_encode_scale;
      float global_encode_scale_multiplier = 1.0f;
      if constexpr (kUseFastMath) {
        global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
      }
      auto sfa_converter = cutlass::NumericConverter<TSFA, ElementAccumulator>{};
      do {
        CUTLASS_PRAGMA_NO_UNROLL
        for (int k_tile = 0;
             k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) {
          int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler);
          int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors, global_tile_n_offset * M,
                                                    packed_N, M, offsets);
          if (cur_group_idx != group_idx) {
            group_idx = cur_group_idx;
            a_global_amax_val = shared_storage.global_a_amax[group_idx];
            // Update group quantization parameters/scaling
            global_encode_scale =
                a_global_amax_val > 0.0f
                    ? cutlass::minimum_with_nan_propagation<float>{}(
                          (fp8_max * fp4_max) / a_global_amax_val,
                          cutlass::platform::numeric_limits<float>::max())
                    : 1.0f;
            global_decode_scale = 1.0f / global_encode_scale;
            if constexpr (kUseFastMath) {
              global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
            }
          }
          auto tQAgSFA_mn = tQAgSFA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile);
          auto tQAgQA_mn = tQAgQA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile);
          auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state);
          mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
          copy(tiled_s2r, tQAsA(_, _, _, mainloop_pipe_consumer_state.index()), tQArA);
          cutlass::arch::fence_view_async_shared();
          mainloop_pipeline.consumer_release(mainloop_pipe_consumer_state);
          ++mainloop_pipe_consumer_state;
          ++k_tile;
          // static int constexpr NumVecs = size(tQArA) / VectorSize;
          cutlass::maximum_absolute_value_reduction<cutlass::Array<ElementAccumulator, VectorSize>, true>
              amax_reduction;
          auto compute_frgs = reinterpret_cast<cutlass::Array<TA, VectorSize> *>(tQArA.data());
          auto output_frgs =
              reinterpret_cast<cutlass::Array<TQA, VectorSize> *>(raw_pointer_cast(tQArQA.data()));
          Tensor amax = make_tensor<ElementAccumulator>(prepend(take<1, rank(tQArA)>(tQArA.shape()), _1{}));
          Tensor pvscales = make_tensor_like<ElementAccumulator>(amax);
          transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
          if constexpr (kEnableStochasticRounding) {
            const size_t rng_sequence = global_thread_idx + k_tile * 512 +
                                        scheduler.get_linear_tile_idx() * K_TILE_MAX * 512 +
                                        tiles_in_m * tiles_in_n * K_TILE_MAX * 512;
            rng.init(rng_seed, rng_sequence, rng_offset);
          }
          CUTLASS_PRAGMA_UNROLL
          for (int v = 0; v < size<1>(group_modes<1, rank(tQArA)>(tQArA)); v++) {
            auto amax_view = group_modes<1, rank(amax)>(amax);
            auto pvscales_view = group_modes<1, rank(pvscales)>(pvscales);
            auto compute_frgs_up =
                cutlass::NumericArrayConverter<ElementAccumulator, TA, VectorSize>{}(compute_frgs[v]);
            amax_view(_0{}, v) = amax_reduction(ElementAccumulator(0), compute_frgs_up);
            if constexpr (kUseFastMath) {
              // Fast math: multiply with precomputed reciprocal
              pvscales_view(_0{}, v) =
                  cutlass::multiplies<ElementAccumulator>{}(amax_view(_0{}, v), global_encode_scale_multiplier);
            } else {
              // Accurate math: perform division
              pvscales_view(_0{}, v) = cutlass::divides<ElementAccumulator>{}(amax_view(_0{}, v), fp4_max);
              pvscales_view(_0{}, v) =
                  cutlass::multiplies<ElementAccumulator>{}(pvscales_view(_0{}, v), global_encode_scale);
            }
            filter(tQArSFA)(v) = sfa_converter(pvscales_view(_0{}, v));
            auto qpvscale_ups = cutlass::NumericConverter<ElementAccumulator, TSFA>{}(filter(tQArSFA)(v));
            auto qpvscale_scaled = cutlass::multiplies<ElementAccumulator>{}(qpvscale_ups, global_decode_scale);
            ElementAccumulator acc_scales;
            if constexpr (kUseFastMath) {
              // Fast math: compute approximate reciprocal
              acc_scales = cutlass::reciprocal_approximate_ftz<decltype(qpvscale_scaled)>{}(qpvscale_scaled);
            } else {
              // Accurate math: compute reciprocal with division
              acc_scales = cutlass::divides<ElementAccumulator>{}(1.0, qpvscale_scaled);
            }
            auto acc_scale = cutlass::minimum_with_nan_propagation<ElementAccumulator>{}(
                acc_scales, cutlass::platform::numeric_limits<ElementAccumulator>::max());
            uint4 random_uint4 = uint4{0, 0, 0, 0};
            if constexpr (kEnableStochasticRounding) {
              random_uint4 = rng.generate4();
              output_frgs[v] = StochasticNumericConverter(
                  cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(compute_frgs_up, acc_scale),
                  *reinterpret_cast<cutlass::Array<uint32_t, 4> *>(&random_uint4));
            } else {
              output_frgs[v] = cutlass::NumericArrayConverter<TQA, ElementAccumulator, VectorSize>{}(
                  cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(compute_frgs_up, acc_scale));
            }
          }
          copy(tiled_r2g_QA, tQArQA, tQAgQA_mn);
          copy(tiled_r2g_SFA, filter(tQArSFA), filter(tQAgSFA_mn));
        }
        // scheduler.advance();
        scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state);
        ++sched_pipeline_consumer_state;
        scheduler.update_work_tile_info();
      } while (scheduler.is_valid());
    }
  } else {
    cutlass::arch::warpgroup_reg_dealloc<32>();
  }
}  // NOLINT(readability/fn_size)
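Both epilogue warps above apply the same two-level NVFP4 scaling recipe: a per-tensor encode/decode scale derived from the tensor-wide amax, then a per-16-element block scale stored in FP8 E4M3, and finally a per-element multiplier that folds the two together. A minimal scalar sketch of that recipe follows; it is illustrative only (the kernel uses vectorized CUTLASS converters plus FP8/FP4 rounding and optional stochastic rounding, all omitted here), and the helper names are not from the source.

#include <algorithm>
#include <cmath>
#include <limits>

constexpr float kFp4Max = 6.0f;    // largest magnitude representable in FP4 (E2M1)
constexpr float kFp8Max = 448.0f;  // largest magnitude representable in FP8 (E4M3)

// Per-tensor scales derived from the tensor-wide amax.
inline void global_scales(float tensor_amax, float *encode, float *decode) {
  *encode = tensor_amax > 0.0f
                ? std::min((kFp8Max * kFp4Max) / tensor_amax, std::numeric_limits<float>::max())
                : 1.0f;
  *decode = 1.0f / *encode;
}

// Quantize one 16-element block: emit a block scale (to be stored as FP8 E4M3)
// and the scaled values that would subsequently be rounded to FP4.
inline void quantize_block(const float *x, float encode, float decode,
                           float *block_scale_out, float *scaled_out) {
  float amax = 0.0f;
  for (int i = 0; i < 16; ++i) amax = std::max(amax, std::fabs(x[i]));
  const float block_scale = (amax / kFp4Max) * encode;  // mirrors pvscales above
  *block_scale_out = block_scale;
  // Per-element multiplier undoes the stored block scale and the global decode scale,
  // clamped against FLT_MAX the same way the kernel clamps acc_scale.
  const float acc_scale = std::min(1.0f / (block_scale * decode), std::numeric_limits<float>::max());
  for (int i = 0; i < 16; ++i) scaled_out[i] = x[i] * acc_scale;
}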
template <bool kEnableStochasticRounding, bool kEnableRHTColQuant, bool kEnableRowQuant,
          bool kEnableSwizzleSFOutput, class TA, class TB, class TQA, class TSFA, class TD = TQA,
          class TSFD = TSFA, bool kUseFastMath = false>
void group_row_col_rht_gemm_ntt_w_sfc_graph_safe(
    int packed_sequence_length, int hidden_size, size_t num_tensors, ShapeRepresentation shape_rep,
    TA const *A, TB const *B, TQA *QA, TSFA *SFA, TQA *QA_COLWISE, TSFA *SFA_COLWISE,
    float *amax_rowwise, float *amax_colwise, const int64_t *offsets, const int64_t *first_dims,
    const size_t *rng_state, uint32_t *tile_scheduler_workspace, uint32_t sm_count,
    cudaStream_t stream, int k_tile_size = 1024) {
  using namespace cute;
  static int constexpr SFVecSize = 16;
  static int constexpr RhtTensorSize = 16;
  static_assert(RhtTensorSize == 16, "RhtTensorSize must be 16");
  using LinearSFALayout = decltype(make_layout(make_shape(make_shape(Int<SFVecSize>{}, 0), 0),
                                               make_stride(make_stride(_0{}, _1{}), 0)));
  using LinearSFDLayout = decltype(make_layout(make_shape(0, make_shape(Int<SFVecSize>{}, 0)),
                                               make_stride(0, make_stride(_0{}, _1{}))));
  using SwizzledSFALayoutAtom =
      cutlass::detail::Sm1xxBlockScaledOutputConfig<SFVecSize, UMMA::Major::MN>::SfAtom;
  using SwizzledSFDLayoutAtom =
      cutlass::detail::Sm1xxBlockScaledOutputConfig<SFVecSize, UMMA::Major::K>::SfAtom;
  using SwizzledSFALayout = decltype(tile_to_shape(SwizzledSFALayoutAtom{},
                                                   make_shape(hidden_size, packed_sequence_length),
                                                   Step<_1, _2>{}));
  using SwizzledSFDLayout = decltype(tile_to_shape(SwizzledSFDLayoutAtom{},
                                                   make_shape(hidden_size, packed_sequence_length),
                                                   Step<_2, _1>{}));
  using SFALayout = cute::conditional_t<kEnableSwizzleSFOutput, SwizzledSFALayout, LinearSFALayout>;
  using SFDLayout = cute::conditional_t<kEnableSwizzleSFOutput, SwizzledSFDLayout, LinearSFDLayout>;
  SFALayout sfa_layout;
  SFDLayout sfd_layout;
  if constexpr (kEnableSwizzleSFOutput) {
    sfa_layout = tile_to_shape(SwizzledSFALayoutAtom{}, make_shape(hidden_size, packed_sequence_length),
                               Step<_1, _2>{});
    sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(hidden_size, packed_sequence_length),
                               Step<_2, _1>{});
  } else {
    sfa_layout = make_layout(make_shape(make_shape(Int<SFVecSize>{}, hidden_size / SFVecSize),
                                        packed_sequence_length),
                             make_stride(make_stride(_0{}, _1{}), hidden_size / SFVecSize));
    sfd_layout = make_layout(make_shape(hidden_size,
                                        make_shape(Int<SFVecSize>{}, packed_sequence_length / SFVecSize)),
                             make_stride(packed_sequence_length / SFVecSize, make_stride(_0{}, _1{})));
  }
  // Define shapes (dynamic)
  auto M = hidden_size;
  auto N = packed_sequence_length;
  Tensor tensorA = make_tensor(A, make_shape(hidden_size, packed_sequence_length), LayoutLeft{});
  Tensor tensorB = make_tensor(B, make_shape(RhtTensorSize, RhtTensorSize), LayoutLeft{});
  Tensor tensorQA = make_tensor(QA, make_shape(hidden_size, packed_sequence_length), LayoutLeft{});
  Tensor tensorSFA = make_tensor(SFA, sfa_layout);
  // Define strides (from tensors)
  auto dA = stride(tensorA);    // (dM,dK)
  auto dB = stride(tensorB);    // (dN,dK)
  auto dD = LayoutRight{};      // (dM,dN)
  auto dQA = stride(tensorQA);  // (dM,dK)
  using ClusterShape = Shape<_1, _1, _1>;
  auto cluster_shape = ClusterShape{};
  auto cluster_tile_shape = Shape<_128, Int<RhtTensorSize>, Int<RhtTensorSize>>{};
  auto cluster_tile_mainloop = Shape<_128, Int<RhtTensorSize>, _128>{};
  // Each mainloop / epilogue loads 128 x 64 tiles while each MMA proceeds with 128 x 16 tiles
  static int constexpr EpilogueUnrollFactor = size<2>(cluster_tile_mainloop) / size<2>(cluster_tile_shape);
  // Construct the MMA
  auto mma = make_tiled_mma(SM100_MMA_F16BF16_SS<TA, TB, float, size<0>(cluster_tile_shape),
                                                 size<1>(cluster_tile_shape), UMMA::Major::MN,
                                                 UMMA::Major::MN>{},
                            Layout<Shape<_1, _1>>{});
  // Assert that the TiledMMA uses all CTAs in the CGA.
  CUTE_STATIC_ASSERT_V(size(cluster_shape) == size(mma));
  CUTE_STATIC_ASSERT_V(evenly_divides(cluster_tile_shape, tile_shape(mma)));
  // Determine the A and B shapes
  auto mma_shape_B = partition_shape_B(mma, make_shape(size<1>(cluster_tile_shape), size<2>(cluster_tile_shape)));
  using TiledMma = decltype(mma);
  using AtomThrID = typename TiledMma::AtomThrID;
  using SmemShape_M = decltype(shape_div(shape<0>(cluster_tile_shape),
                                         shape_div(shape<0>(cluster_tile_shape),
                                                   size<0>(cluster_tile_shape) / size(AtomThrID{}))));
  using SmemShape_N = decltype(shape_div(shape<1>(cluster_tile_shape),
                                         shape_div(shape<1>(cluster_tile_shape),
                                                   size<1>(cluster_tile_shape) / size(AtomThrID{}))));
  using SmemShape_K = decltype(cute::get<2>(cluster_tile_shape));
  using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
                                   cute::UMMA::Major::MN, TB, SmemShape_N, SmemShape_K>());
  auto mma_shape_A = partition_shape_A(mma, make_shape(size<0>(cluster_tile_mainloop),
                                                       size<2>(cluster_tile_mainloop)));
  using SmemShape_M_A = decltype(shape_div(shape<0>(cluster_tile_mainloop),
                                           shape_div(shape<0>(cluster_tile_mainloop),
                                                     size<0>(cluster_tile_mainloop) / size(AtomThrID{}))));
  using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop));
  using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
                                   cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>());
  static uint32_t constexpr TotalTmemRows = 128;
  static uint32_t constexpr Sm100TmemCapacityColumns = 512;
  static uint32_t constexpr TotalTmem = TotalTmemRows * Sm100TmemCapacityColumns;
  static uint32_t constexpr AccumulatorPipelineStageCount =
      TotalTmem / (cute::size<0>(cluster_tile_shape) * cute::size<1>(cluster_tile_shape));
  // Define the smem layouts (static)
  // Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory
  constexpr int SchedulerPipelineStageCount = 4;
  static int constexpr MainloopPipelineBytes =
      sizeof(typename cutlass::detail::CustomizedPipelineTmaUmmaAsync<1, Shape<_1, _1, _1>,
                                                                      Shape<_1, _1, _1>>::SharedStorage);
  static int constexpr SchedulerWorkspaceBytes = sizeof(int) * SchedulerPipelineStageCount;
  static int constexpr SchedulerThrottlePipelineBytes =
      sizeof(typename cutlass::PipelineAsync<SchedulerPipelineStageCount>::SharedStorage);
  static int constexpr SchedulerPipelineBytes =
      sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape>::SharedStorage);
  static int constexpr TmemDeallocBytes = sizeof(cutlass::arch::ClusterBarrier);
  static int constexpr BTensorBytes = cute::size(mma_shape_B) * sizeof(TB);
  static int constexpr AccPipelineBytes =
      sizeof(typename cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / EpilogueUnrollFactor,
                                                 Shape<_1, _1, _1>>::SharedStorage);
  static int constexpr TmemBasePtrsBytes = sizeof(uint32_t);
  static int constexpr kBlackwellSmemSize = 232448;  // 232KB in bytes
  static int constexpr kBytesPerStage = cute::size(mma_shape_A) * sizeof(TA) + MainloopPipelineBytes;
  static int constexpr kReservedBytes = SchedulerWorkspaceBytes + SchedulerThrottlePipelineBytes +
                                        SchedulerPipelineBytes + TmemBasePtrsBytes + TmemDeallocBytes +
                                        BTensorBytes + AccPipelineBytes;  // Reserve for barriers and other uses
  static int constexpr kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage;
  auto sP = Int<kMaxStages>{};  // SMEM pipelines
  auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP), Step<_2, _1, _3>{});  // (MMA,MMA_M,MMA_K,PIPE)
  auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, append(mma_shape_B, _1{}));  // (MMA,MMA_N,MMA_K, _1)
  auto sD = Layout<_1>{};  // XXX Dummy
  auto tma_load_a = make_tma_copy_A_sm100(SM90_TMA_LOAD{}, tensorA, sA(_, _, _, 0), cluster_tile_mainloop, mma);
  auto tma_load_b = make_tma_copy_B_sm100(SM90_TMA_LOAD{}, tensorB, sB(_, _, _, 0), cluster_tile_shape, mma);
  // Assert checks on tile sizes -- no predication
  assert(M % size<0>(cluster_tile_shape) == 0);
  assert(N % size<1>(cluster_tile_shape) == 0);
  dim3 dimBlock(512);
  dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape));
  dim3 dimGrid(sm_count, 1, 1);
  int smem_size = sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB), ClusterShape,
                                       AccumulatorPipelineStageCount, EpilogueUnrollFactor,
                                       SchedulerPipelineStageCount>);
  auto *kernel_ptr = &group_row_col_rht_gemm_device_graph_safe<
      decltype(M), decltype(N), decltype(k_tile_size), decltype(cluster_shape), decltype(cluster_tile_shape),
      TA, decltype(dA), decltype(sA), decltype(tma_load_a), TB, decltype(dB), decltype(sB), decltype(tma_load_b),
      TD, decltype(dD), decltype(sD), TSFD, decltype(sfd_layout), TQA, decltype(dQA), TSFA, decltype(sfa_layout),
      decltype(mma), AccumulatorPipelineStageCount, SchedulerPipelineStageCount, kEnableStochasticRounding,
      kEnableRHTColQuant, kEnableRowQuant, kEnableSwizzleSFOutput, kUseFastMath>;
  NVTE_CHECK_CUDA(cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  // Set workspace and set to zero
  NVTE_CHECK_CUDA(cudaMemsetAsync(reinterpret_cast<void *>(tile_scheduler_workspace), 0, sizeof(uint32_t), stream));
  // Launch kernel
  cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream};
  cutlass::Status status = cutlass::launch_kernel_on_cluster(
      params, (void const *)kernel_ptr, M, N, k_tile_size, cluster_shape, cluster_tile_shape, A, dA, sA,
      tma_load_a, B, dB, sB, tma_load_b, QA, dQA, SFA, sfa_layout, QA_COLWISE, SFA_COLWISE, amax_rowwise,
      amax_colwise, offsets, first_dims, num_tensors, shape_rep, tile_scheduler_workspace, mma, rng_state);
  NVTE_CHECK_CUDA(cudaGetLastError());
  NVTE_CHECK(status == cutlass::Status::kSuccess, "Kernel launch failed.");
}

}  // namespace

}  // namespace detail
void group_hadamard_transform_cast_fusion_graph_safe(const GroupedTensor *input, GroupedTensor *output,
                                                     const Tensor &hadamard_matrix_,
                                                     QuantizationConfig &quant_config,
                                                     Tensor &quant_workspace, cudaStream_t stream) {
  NVTE_API_CALL(group_hadamard_transform_cast_fusion_graph_safe);
  using transformer_engine::detail::kMaxTensorsPerKernel;
  using transformer_engine::detail::ShapeRepresentation;
  void *input_base_ptr = reinterpret_cast<void *>(input->data.dptr);
  // TODO(zhongbo): add input sanity checks here
  bool all_has_row_quant = output->has_data();
  bool all_has_col_quant = output->has_columnwise_data();
  // Stochastic rounding config
  const bool use_stochastic_rounding = quant_config.stochastic_rounding;
  const size_t *rng_state = nullptr;
  if (use_stochastic_rounding) {
    NVTE_CHECK(quant_config.rng_state != nullptr, "Enabled stochastic rounding without providing RNG state");
    const Tensor &rng_state_tensor = *convertNVTETensorCheck(quant_config.rng_state);
    NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, "RNG state should contain 2 64-bit values.");
    NVTE_CHECK(rng_state_tensor.data.shape == std::vector<size_t>{2},
               "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape);
    rng_state = reinterpret_cast<const size_t *>(rng_state_tensor.data.dptr);
  }
  uint32_t *tile_scheduler_workspace = nullptr;
  NVTE_CHECK(quant_workspace.data.dptr != nullptr, "Quantization workspace must be provided.");
  NVTE_CHECK(quant_workspace.data.buffer_size_bytes() >= sizeof(uint32_t),
             "Quantization workspace must be at least 4 bytes.");
  tile_scheduler_workspace = reinterpret_cast<uint32_t *>(quant_workspace.data.dptr);
  // Template arguments
  using TA = cute::bfloat16_t;
  using TB = cute::bfloat16_t;
  using TD = cutlass::float_e2m1_t;
  using TSFD = cutlass::float_ue4m3_t;
  using TQA = TD;
  using TSFA = TSFD;
  checkCuDriverContext(stream);
  // Check Hadamard matrix
  constexpr int kHadamardDimension = 16;
  NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16,
             "Hadamard matrix must be BF16 tensor, but dtype is ", to_string(hadamard_matrix_.dtype()), ".");
  const SimpleTensor &hadamard_matrix = hadamard_matrix_.data;
  NVTE_CHECK((hadamard_matrix_.shape() == std::vector<size_t>{kHadamardDimension, kHadamardDimension}),
             "Hadamard matrix must have shape=", std::vector<size_t>{kHadamardDimension, kHadamardDimension},
             ", but got shape=", hadamard_matrix_.shape(), ".");
  const size_t hadamard_dimension = hadamard_matrix.shape[0];
  const size_t num_tensors = input->num_tensors;
  const size_t first_logical_dim = input->logical_shape.data[0];
  const size_t last_logical_dim = input->logical_shape.data[1];
  // const size_t elts_total = first_logical_dim * last_logical_dim;
  NVTE_CHECK(first_logical_dim % 128 == 0, "First dimension of a grouped tensor should be divisible by 128.");
  NVTE_CHECK(last_logical_dim % 128 == 0, "Last dimension of a grouped tensor should be divisible by 128.");
  NVTE_CHECK(num_tensors <= kMaxTensorsPerKernel,
             "Number of tensors should be less than or equal to ", kMaxTensorsPerKernel);
  ShapeRepresentation shape_rep = ShapeRepresentation::VARYING_FIRST_DIM;
  if (output->all_same_shape()) {
    shape_rep = ShapeRepresentation::SAME_BOTH_DIMS;
  } else if (output->all_same_first_dim()) {
    shape_rep = ShapeRepresentation::VARYING_LAST_DIM;
  } else if (output->all_same_last_dim()) {
    shape_rep = ShapeRepresentation::VARYING_FIRST_DIM;
  } else if (output->varying_both_dims()) {
    shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS;
  }
  TQA *const rowwise_data_base_ptr = reinterpret_cast<TQA *>(output->data.dptr);
  TSFA *const rowwise_scale_inv_base_ptr = reinterpret_cast<TSFA *>(output->scale_inv.dptr);
  TQA *const colwise_data_base_ptr = reinterpret_cast<TQA *>(output->columnwise_data.dptr);
  TSFA *const colwise_scale_inv_base_ptr = reinterpret_cast<TSFA *>(output->columnwise_scale_inv.dptr);
  float *const amax_rowwise_base_ptr = reinterpret_cast<float *>(output->amax.dptr);
  float *const amax_colwise_base_ptr = reinterpret_cast<float *>(output->columnwise_amax.dptr);
  const int64_t *const offsets_ptr = reinterpret_cast<const int64_t *>(input->tensor_offsets.dptr);
  const int64_t *const first_dims_ptr = reinterpret_cast<const int64_t *>(input->first_dims.dptr);
  // const int64_t *const last_dims_ptr = reinterpret_cast<const int64_t *>(input->last_dims.dptr);
  const bool is_const_last_dim = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS ||
                                  shape_rep == ShapeRepresentation::VARYING_FIRST_DIM);
  NVTE_CHECK(is_const_last_dim, "Currently we only support const last dimension for graph safe hadamard transform.");
  auto sm_count = transformer_engine::cuda::sm_count();
  int k_tile_size = 1024;
  const bool use_swizzle_sf_output = false;
  TRANSFORMER_ENGINE_SWITCH_CONDITION(
      use_stochastic_rounding, kEnableStochasticRounding,
      TRANSFORMER_ENGINE_SWITCH_CONDITION(
          all_has_col_quant, kEnableRhtColQuant,
          TRANSFORMER_ENGINE_SWITCH_CONDITION(
              all_has_row_quant, kEnableRowQuant,
              TRANSFORMER_ENGINE_SWITCH_CONDITION(
                  use_swizzle_sf_output, kEnableSwizzleSFOutput,
                  TRANSFORMER_ENGINE_SWITCH_CONDITION(
                      quant_config.use_fast_math, kUseFastMath,
                      if constexpr (kEnableRhtColQuant || kEnableRowQuant) {
                        detail::group_row_col_rht_gemm_ntt_w_sfc_graph_safe<
                            kEnableStochasticRounding, kEnableRhtColQuant, kEnableRowQuant,
                            kEnableSwizzleSFOutput, TA, TB, TQA, TSFA, TD, TSFD, kUseFastMath>(
                            /*packed_sequence_length=*/first_logical_dim,
                            /*hidden_size=*/last_logical_dim,
                            /*num_tensors=*/num_tensors,
                            /*shape_rep=*/shape_rep,
                            /*A=*/reinterpret_cast<TA const *>(input_base_ptr),
                            /*B=*/reinterpret_cast<TB const *>(hadamard_matrix.dptr),
                            /*QA=*/reinterpret_cast<TQA *>(rowwise_data_base_ptr),
                            /*SFA=*/reinterpret_cast<TSFA *>(rowwise_scale_inv_base_ptr),
                            /*QA_COLWISE=*/reinterpret_cast<TQA *>(colwise_data_base_ptr),
                            /*SFA_COLWISE=*/reinterpret_cast<TSFA *>(colwise_scale_inv_base_ptr),
                            /*amax_rowwise=*/reinterpret_cast<float *>(amax_rowwise_base_ptr),
                            /*amax_colwise=*/reinterpret_cast<float *>(amax_colwise_base_ptr),
                            /*offsets=*/offsets_ptr,
                            /*first_dims=*/first_dims_ptr,
                            /*rng_state=*/rng_state,
                            /*tile_scheduler_workspace=*/tile_scheduler_workspace,
                            /*sm_count=*/sm_count,
                            /*stream=*/stream,
                            /*k_tile_size=*/k_tile_size);
                      } else {
                        NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=", kEnableRhtColQuant,
                                   ", kEnableRowQuant=", kEnableRowQuant, ").");
                      });););););
}

}  // namespace transformer_engine
void nvte_group_hadamard_transform_cast_fusion_graph_safe(const NVTEGroupedTensor input,
                                                          NVTEGroupedTensor output,
                                                          const NVTETensor hadamard_matrix,
                                                          const NVTEQuantizationConfig quant_config,
                                                          NVTETensor quant_workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_hadamard_transform_cast_fusion_graph_safe);
  using namespace transformer_engine;
  GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input);
  GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output);
  Tensor *quant_workspace_tensor = convertNVTETensorCheck(quant_workspace);
  QuantizationConfig quant_config_cpp;
  if (quant_config != nullptr) {
    quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
  }
  if (input_tensor->num_tensors == 0) {
    return;
  }
  // Call the multi-tensor Hadamard transform amax implementation.
  group_hadamard_transform_cast_fusion_graph_safe(input_tensor, output_tensor,
                                                  *convertNVTETensorCheck(hadamard_matrix),
                                                  quant_config_cpp, *quant_workspace_tensor, stream);
}
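For reference, a minimal calling sketch for this new entry point follows. It assumes the NVTEGroupedTensor/NVTETensor handles have already been created and populated elsewhere (a grouped BF16 input, a grouped NVFP4 output with rowwise and/or columnwise buffers, a 16x16 BF16 Hadamard matrix, and a workspace of at least 4 bytes); the wrapper name is illustrative.

// Illustrative wrapper; handle construction is assumed to happen elsewhere.
void run_group_rht_cast_fusion(NVTEGroupedTensor grouped_input, NVTEGroupedTensor grouped_output,
                               NVTETensor hadamard_16x16, NVTEQuantizationConfig quant_cfg,
                               NVTETensor workspace, cudaStream_t stream) {
  // No-op when the group is empty; otherwise launches the fused RHT + NVFP4 cast kernel.
  nvte_group_hadamard_transform_cast_fusion_graph_safe(grouped_input, grouped_output,
                                                       hadamard_16x16, quant_cfg, workspace, stream);
}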
transformer_engine/common/include/transformer_engine/activation.h
View file @
9df0c4a3
...
...
@@ -31,6 +31,7 @@ extern "C" {
enum class NVTE_Activation_Type {
  GELU,
  GEGLU,
  GLU,
  SILU,
  SWIGLU,
  RELU,
...
...
@@ -52,6 +53,16 @@ enum class NVTE_Activation_Type {
*/
void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the GeLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
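A minimal usage sketch for the grouped activation entry points declared in this hunk; the grouped tensor handles are assumed to be built and shaped per group elsewhere, and the wrapper name is illustrative.

// Applies GeLU to every member tensor of a grouped input in a single call.
void grouped_gelu_forward(NVTEGroupedTensor grouped_input, NVTEGroupedTensor grouped_output,
                          cudaStream_t stream) {
  nvte_group_gelu(grouped_input, grouped_output, stream);
}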
/*! \brief Computes the SiLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
...
...
@@ -62,6 +73,16 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the SiLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
...
...
@@ -72,6 +93,16 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the ReLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
...
...
@@ -82,6 +113,16 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
...
...
@@ -92,6 +133,16 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the GeLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
...
...
@@ -104,6 +155,18 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the GeLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTETensor output,
                      cudaStream_t stream);
/*! \brief Computes the SiLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
...
...
@@ -116,6 +179,18 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the SiLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTETensor output,
                      cudaStream_t stream);
/*! \brief Computes the ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
...
...
@@ -128,6 +203,18 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the ReLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTETensor output,
                      cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
...
...
@@ -140,6 +227,18 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTETensor output,
                       cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
...
...
@@ -152,6 +251,44 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTETensor output,
                       cudaStream_t stream);
/*! \brief Computes the GLU (Gated Linear Unit) activation of the input.
* GLU(a,b) = sigmoid(a) * b
* See "Language Modeling with Gated Convolutional Networks" (arXiv:1612.08083)
* and "GLU Variants Improve Transformer" (arXiv:2002.05202).
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes sigmoid(input[N, :H]) x input[N, H:]
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the GLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the gated GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
...
...
transformer_engine/common/include/transformer_engine/cast.h
View file @
9df0c4a3
...
...
@@ -89,6 +89,17 @@ extern "C" {
*/
void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Casts input grouped tensor to MXFP8.
* The type of quantized tensor in the output depends on the scaling mode of the output
* tensor. See file level comments.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in,out] output Output grouped MXFP8 tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
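A minimal usage sketch for the grouped quantize entry point; it assumes the grouped handles are pre-built, with the output configured for MXFP8 scaling (data and scale-inverse buffers allocated per group). The wrapper name is illustrative.

// Quantizes every member tensor of a grouped input to MXFP8 in a single call.
void grouped_quantize_mxfp8(NVTEGroupedTensor grouped_input, NVTEGroupedTensor grouped_output,
                            cudaStream_t stream) {
  nvte_group_quantize(grouped_input, grouped_output, stream);
}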
/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel
* based on the value of the 'noop' tensor.
* The type of quantized tensor in the output depends on the scaling mode of the output
...
...
@@ -132,6 +143,26 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
                         NVTETensor workspace, cudaStream_t stream);
/*! \brief Casts input grouped tensor to MXFP8. Additionally, reduces the input along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, NVTETensor dbias,
                               NVTETensor workspace, cudaStream_t stream);
/*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the GeLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
...
...
@@ -155,6 +186,29 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream);
/*! \brief Computes backward of GeLU operation on the grouped input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the GeLU backward along columns.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in] act_input Activation input grouped tensor.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input,
                                     NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace,
                                     cudaStream_t stream);
/*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the SiLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
...
...
@@ -178,6 +232,29 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream);
/*! \brief Computes backward of SiLU operation on the grouped input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the SiLU backward along columns.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in] act_input Activation input grouped tensor.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input,
                                     NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace,
                                     cudaStream_t stream);
/*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the ReLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
...
...
@@ -201,6 +278,29 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream);
/*! \brief Computes backward of ReLU operation on the grouped input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the ReLU backward along columns.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in] act_input Activation input grouped tensor.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input,
                                     NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace,
                                     cudaStream_t stream);
/*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Quick GeLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
...
...
@@ -224,6 +324,29 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream);
/*! \brief Computes backward of Quick GeLU operation on the grouped input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Quick GeLU backward along columns.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in] act_input Activation input grouped tensor.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input,
                                      NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace,
                                      cudaStream_t stream);
/*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Squared ReLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
...
...
@@ -247,6 +370,29 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream);
/*! \brief Computes backward of Squared ReLU operation on the grouped input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Squared ReLU backward along columns.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in] act_input Activation input grouped tensor.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input,
                                      NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace,
                                      cudaStream_t stream);
/*! \brief Casts input tensor from reduced to higher precision.
* If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING,
* the block dequantization (MXFP8) of the specified shape of the block will be used.
...
...