Unverified Commit eb8e792b authored by Zhongbo Zhu, committed by GitHub

[PyTorch][NVFP4][MOE] NVFP4 Grouped Quantize with Hadamard Transform (#2411)



* rowwise colwise RHT group quant v1
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* remove local array RW
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* change wait_barrier
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* fast math options
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* use mult to replace div
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* format
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* bulk move random states
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* greptile
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* lint
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* revert to use divides
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* avoid fp32 bf16 round-trip in RHT cast fusion
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* trigger fastmath by toggle NVTE_RHT_CAST_FUSION_USE_FAST_MATH
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* integrate row col rht fusion, functional
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* numerics aligned
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* style
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* remove device sync
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* 128 padding
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* revert colwise rng state creation because of row-col fused kernel
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* fix CI, linter
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* refactor RS for generating two random values
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* Avoid invalid configs with templated kernel
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* fix acc pipeline init with 0 arrival count
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* restore rowwise-only mode
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* switch to dynamic atomic scheduler
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* Avoid instantiating group RHT+cast kernel without row-wise or col-wise output
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Include fast math option in quantization config
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings and review nits
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use TE license
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug where kernel is always launched on stream
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Restore BF16 intermediate downcast in fused RHT-cast kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* fix numerical test of grouped kernel
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* Make sure row-wise and col-wise quantization use different RNG seeds
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Restore autoformatter
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 47902e96
......@@ -53,7 +53,7 @@ ncu -f -o ./benchmarks/linear/ncu_b200_numgemm_8_nvfp4_rht_amax \
--set=full \
--kernel-name "GroupHadamardAmaxTmaKernel" \
-s 5 -c 5 \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4 --profile
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4
"""
......@@ -173,7 +173,9 @@ def benchmark_linear(
return timing_ms
def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None):
def run_benchmark_linear(
mkns, recipe_name, use_bias, num_gemms=4, m_splits_provided=None, fwd_only=False
):
data = []
assert not use_bias, "Bias is not supported for GroupedLinear benchmark"
......@@ -182,14 +184,14 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
device = "cuda"
x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
assert m % num_gemms == 0
m_splits = [m // num_gemms] * num_gemms if m_splits is None else m_splits
m_splits = [m // num_gemms] * num_gemms if m_splits_provided is None else m_splits_provided
# Bias is not supported for GroupedLinear benchmark
bias = None
# Run the benchmark
print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}")
print(f"m_splits: {m_splits}")
print(f"fwd_only: {fwd_only}")
grouped_fwd_bwd_timing_ms = benchmark_linear(
x,
......@@ -197,7 +199,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
m_splits,
bias,
recipe_name,
mode="fwd_bwd",
mode="fwd_only" if fwd_only else "fwd_bwd",
num_gemms=num_gemms,
)
......@@ -213,6 +215,8 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
]
)
timing_notation = "grouped_fwd_time_ms" if fwd_only else "grouped_fwd_bwd_time_ms"
df = pd.DataFrame(
data=data,
columns=[
......@@ -221,7 +225,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
"n",
"recipe",
"num_gemms",
"grouped_fwd_bwd_time_ms",
timing_notation,
],
)
......@@ -234,7 +238,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
parser.add_argument(
"--output_dir",
"--output-dir",
type=str,
default="benchmark_output/",
help="output path for report",
......@@ -266,6 +270,12 @@ if __name__ == "__main__":
default=2048,
help="Output dimension to use, default is 2048",
)
parser.add_argument(
"--fwd-only",
action="store_true",
default=False,
help="Run forward pass only, default is both forward and backward passes",
)
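# Example invocation (illustrative sketch; --recipe is assumed from the module docstring above):
#   python benchmarks/linear/benchmark_grouped_linear.py --recipe nvfp4 --fwd-only --output-dir benchmark_output/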
args = parser.parse_args()
jagged_input_splits = None
......@@ -297,7 +307,7 @@ if __name__ == "__main__":
if jagged_input_splits is not None:
num_gemms_list = [len(jagged_input_splits)]
token_dim_list = [65536]
token_dim_list = [16384, 32768, 65536, 98304]
hidden_dim_list = [7168]
output_dim_list = [2048]
......@@ -371,7 +381,8 @@ if __name__ == "__main__":
recipe_name,
use_bias,
num_gemms=num_gemms,
m_splits=jagged_input_splits,
m_splits_provided=jagged_input_splits,
fwd_only=args.fwd_only,
)
df_linears = pd.concat([df_linears, df])
......
......@@ -198,7 +198,7 @@ def check_group_quantization_nvfp4_versus_reference(
for i in range(len(x_qx)):
if split_sections[i] == 0:
# then just assert the same same and dtype because the buffer won't be zero out
# then just assert the same shape and dtype because the buffer won't be zeroed out
assert_same_shape_and_dtype(x_amax_rowwise[i], x_amax_rowwise_ref[i])
assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i])
assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i])
......@@ -221,7 +221,7 @@ def check_group_quantization_nvfp4_versus_reference(
# assert with zero tolerance
for i in range(len(x_qx_t)):
if split_sections[i] == 0:
# then just assert the same same and dtype because the buffer won't be zero out
# then just assert the same shape and dtype because the buffer won't be zeroed out
assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i])
assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i])
assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i])
......@@ -247,6 +247,7 @@ def check_group_quantization_nvfp4_versus_reference(
(1024, 256),
# larger sizes
(8192, 1024),
(16384, 8192),
(16384, 16384),
],
)
......
......@@ -174,6 +174,8 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
hadamard_transform/group_hadamard_transform.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
hadamard_transform/group_hadamard_transform_cast_fusion.cu
hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu
transpose/quantize_transpose_square_blockwise.cu
......
......@@ -100,3 +100,18 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
}
}
// Group quantize assumes contiguous inputs and outputs in memory allocation
// TODO (zhongbo): find a better way to make this a more general API
void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs,
const size_t *split_sections, const size_t num_tensors,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_nvfp4_quantize_with_amax);
using namespace transformer_engine;
constexpr bool IS_ACT = false;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, outputs, split_sections,
num_tensors, quant_config, stream);
}
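// Illustrative usage sketch (not part of this change; the tensor setup helper below is
// hypothetical): the caller passes one contiguous BF16 input holding all splits back to back,
// plus per-split NVFP4 output tensors whose amax/scale_inv buffers are already allocated.
//
//   std::vector<size_t> split_sections = {4096, 4096, 8192};
//   std::vector<NVTETensor> outputs = make_nvfp4_output_tensors(split_sections, hidden_dim);
//   nvte_group_nvfp4_quantize_with_amax(input, outputs.data(), split_sections.data(),
//                                       outputs.size(), quant_config, stream);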
......@@ -19,6 +19,7 @@
#include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh"
#include "../nvfp4/group_quantize_transpose_nvfp4.cuh"
#include "../nvfp4/quantize_nvfp4.cuh"
#include "../nvfp4/quantize_transpose_nvfp4.cuh"
......@@ -320,6 +321,70 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
}
}
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs,
const size_t *split_sections, const size_t num_tensors,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
using namespace detail;
const Tensor *input_tensor = convertNVTETensorCheck(input);
std::vector<Tensor *> output_tensors;
for (size_t i = 0; i < num_tensors; ++i) {
output_tensors.push_back(convertNVTETensorCheck(outputs[i]));
}
// Quantization config
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// Check for unsupported options
if (quant_config_cpp.stochastic_rounding) {
NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Stochastic rounding is only supported for NVFP4 quantization.");
}
// Take the scaling mode of the first output tensor
auto scaling_mode = output_tensors[0]->scaling_mode;
// Dispatch to quantization kernel depending on data format
switch (scaling_mode) {
case NVTE_NVFP4_1D_SCALING: {
NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by the NVTE_NVFP4_1D_SCALING forward path");
// Check tensors
CheckNoopTensor(*noop_tensor, "cast_noop");
CheckInputTensor(*input_tensor, "input");
// Skip checking the output tensor list;
// the output list here is allowed to contain empty tensors
// Choose kernel
int32_t rows = input_tensor->flat_first_dim();
int32_t cols = input_tensor->flat_last_dim();
auto dtype = input_tensor->dtype();
NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization,
"2D quantization is not supported for group quantize.");
// Launch NVFP4 group quantize kernel
nvfp4::group_quantize_transpose</*use_2d_quantization*/ false>(
*input_tensor, noop_tensor, output_tensors, split_sections, num_tensors,
&quant_config_cpp, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
}
}
} // namespace dispatch
} // namespace transformer_engine
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file quantize_transpose_nvfp4.cuh
* \brief CUDA kernels to cast to NVFP4 and transpose.
*/
#ifndef TRANSFORMER_ENGINE_GROUP_QUANTIZE_TRANSPOSE_NVFP4_CUH_
#define TRANSFORMER_ENGINE_GROUP_QUANTIZE_TRANSPOSE_NVFP4_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "core_nvfp4.cuh"
namespace transformer_engine {
namespace dispatch {
namespace nvfp4 {
namespace group_quantize_transpose_kernel {
using namespace quantization_and_transposition_SF;
using namespace core;
using namespace ptx;
#if FP4_TYPE_SUPPORTED
constexpr int kMaxTensorsPerKernel = 64;  // Kernel args must stay under 4 KB; increase 64 if needed
struct MultiAmaxCastTransposeFusionArgs {
// Amax buffer for rowwise scaling
void *rowwise_amax_list[kMaxTensorsPerKernel];
// Rowwise scale pointers with 128x4 padding included for rowwise scaling
void *output_rowwise_scale_inv_list[kMaxTensorsPerKernel];
// (Unused for rowwise only scaling) Amax buffer for colwise scaling
void *colwise_amax_list[kMaxTensorsPerKernel];
// (Unused for rowwise only scaling) output data pointers for fp4 transposed output
void *output_colwise_data_list[kMaxTensorsPerKernel];
// (Unused for rowwise only scaling) output scale inverse pointers for each tensor
void *output_colwise_scale_inv_list[kMaxTensorsPerKernel];
// (Unused for rowwise only scaling) output scale stride for colwise scaling
int output_colwise_scale_stride[kMaxTensorsPerKernel];
// Prefix sum (with leading zero) of split_sections of each tensor of input
int split_sections_range[kMaxTensorsPerKernel + 1];
// Number of tensors (splits) being processed by kernel
int num_tensors;
};
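// Size sanity note (added for reference): with kMaxTensorsPerKernel = 64 this struct is roughly
// 3 KB (five 64-entry pointer arrays at 8 B each, plus the two int arrays and num_tensors), so
// together with the remaining kernel parameters it stays under the 4 KB kernel-argument limit.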
__device__ __forceinline__ int GetTensorId(MultiAmaxCastTransposeFusionArgs *kernel_args_ptr,
int offset) {
// check the kernel args and get the corresponding id
int tensor_id = 0;
while (kernel_args_ptr->split_sections_range[tensor_id + 1] <= offset) {
++tensor_id;
}
return tensor_id;
}
// Helper to get tensor id at offset, and also whether [offset_start, offset_end) crosses a split boundary.
__device__ __forceinline__ int GetTensorIdAndBoundary(
MultiAmaxCastTransposeFusionArgs *kernel_args_ptr, int offset_start, int offset_end,
bool *cross_boundary) {
int tensor_id_start = 0;
while (kernel_args_ptr->split_sections_range[tensor_id_start + 1] <= offset_start) {
++tensor_id_start;
}
int tensor_id_end = tensor_id_start;
if (offset_end != offset_start) {
if (kernel_args_ptr->split_sections_range[tensor_id_start + 1] < offset_end) {
tensor_id_end = tensor_id_start + 1;
}
}
if (cross_boundary) {
*cross_boundary = (tensor_id_start != tensor_id_end);
}
return tensor_id_start;
}
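// Example with hypothetical splits: split_sections = {256, 384, 128} gives
// split_sections_range = {0, 256, 640, 768}. A chunk starting at row offset 300 maps to
// tensor_id 1; a chunk covering rows [512, 640) stays inside tensor 1, while one covering
// [600, 728) starts in tensor 1 but crosses into tensor 2, so *cross_boundary is set and the
// caller must re-resolve the tensor id per stage.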
__device__ __forceinline__ void UpdateEncodeDecodeScaleFP32(float *amax_ptr, float *s_enc_ptr,
float *s_dec_ptr) {
const float s_enc_value =
(amax_ptr == nullptr) ? 1.0f : compute_global_encode_scaling_factor_FP4(*amax_ptr);
*s_enc_ptr = s_enc_value;
*s_dec_ptr = 1.0f / s_enc_value;
}
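// Note on the two-level NVFP4 scaling (reference note, describing intent rather than new logic):
// compute_global_encode_scaling_factor_FP4() derives the per-tensor FP32 encode scale from the
// global amax (for NVFP4 this is typically (E4M3 max * E2M1 max) / amax = (448 * 6) / amax), and
// the decode scale is its reciprocal. On top of that, each 16-element block gets an E4M3 scaling
// factor computed from the block amax via compute_decoding_scaling_factor().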
constexpr size_t SCALE_DIM = 16; // NVFP4 block (x16 elts)
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_NUM = 128;
constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM;
constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM;
constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM;
// Each call generates 4x uint32_t random numbers
constexpr size_t RNG_GENS_PER_THREAD = SCALES_PER_THREAD / 4;
constexpr size_t TILE_DIM_Y = 32;
constexpr size_t TILE_DIM_X = 128;
// Should this be SCALE_DIM or BLOCK_DIM? Both are 16, so it works for both 1D and 2D
constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM;
constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 128 / 16 = 8
constexpr size_t TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y;
constexpr size_t TILES_X = CHUNK_DIM_X / TILE_DIM_X;
constexpr size_t STAGES = TILES_Y * TILES_X;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = TILE_DIM_Y;
constexpr size_t BUFF_DIM_X = TILE_DIM_X;
constexpr size_t BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM;
// Input buffer (BF16)
constexpr size_t BUFF_IN_DIM_Y = BUFF_DIM_Y;
constexpr size_t BUFF_IN_DIM_X = BUFF_DIM_X;
constexpr size_t BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X;
// Output buffer (NVFP4)
constexpr size_t BUFF_OUT_DIM_Y = BUFF_DIM_Y;
constexpr size_t BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8;
constexpr size_t BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X;
// Output transpose buffer (NVFP4)
constexpr size_t BUFF_OUT_T_DIM_Y = BUFF_DIM_X;
constexpr size_t BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8;
constexpr size_t BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X;
// Manual swizzling parameters to reduce SHMEM bank conflicts
constexpr size_t PACK_SIZE = 8;
constexpr size_t WAVES = SCALE_DIM / PACK_SIZE;
constexpr size_t SCALING_FACTORS_PER_TILE_X = TILE_DIM_X / SCALE_DIM;
constexpr size_t THREADS_X_ROWWISE = SCALING_FACTORS_PER_TILE_X; // 128 / 16 = 8
constexpr size_t THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 128 / 8 = 16
constexpr size_t ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE; // 32/ 16 = 2
constexpr size_t ITERATIONS_TRANSPOSE = BUFF_IN_DIM_Y / SCALE_DIM;
constexpr size_t BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE;
static_assert(BUFF_DIM_Y >= SCALE_DIM,
"Number of buffer rows must be greater than or equal to the size of the columnwise "
"scaling block");
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y);
static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE,
"Number of buffer rows must be greater than or equal to the number of rowwise "
"processing threads in the Y dimension");
// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM;  // 16 = 256 / 16
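// Derived tiling (reference note): each CTA processes a CHUNK_DIM_Y x CHUNK_DIM_X = 128x128 chunk
// in STAGES = 4 stages of TILE_DIM_Y x TILE_DIM_X = 32x128 tiles, double-buffered through shared
// memory (BUFFS_NUM = 2). For rowwise scaling the 128 threads form a 16x8 grid, so each thread
// produces one scaling factor per 16-element block and runs ITERATIONS_NORMAL = 2 iterations per
// tile.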
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM)
group_quantize_transpose_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output,
nvfp4_scale_t *const scales_ptr, const float *noop,
const size_t rows, const size_t cols,
const size_t scale_stride, const size_t *rng_state,
MultiAmaxCastTransposeFusionArgs kernel_args) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
(!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
using IType2 = typename ptx::FPx2<IType>;
if constexpr (!COMPUTE_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
const size_t rng_sequence =
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
rng.init(rng_seed, rng_sequence, rng_offset);
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0};
// Index of the random number. It increments each time it is used and resets to 0 when it reaches 4
int rnd_idx = 0;
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS;
const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
// TODO(zhongbo): add back when transpose is supported
// const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
// const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_rows = rows - block_offset_Y;
const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X;
// TODO(zhongbo): add back when transpose is supported
// const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
// const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const size_t tid_X_colwise = threadIdx.x;
const size_t tid_Y_t = tid_X_colwise;
// const size_t tid_X_t = 0;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM;
const size_t thread_offset_X_colwise = tid_X_colwise;
const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const size_t row_base_colwise = block_offset_Y;
const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
// TODO(zhongbo): add back when transpose is supported
// const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t;
// const size_t scales_offset_X_t = scales_block_offset_X_t;
const size_t SFs_per_row = cols / SCALE_DIM;
const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row;
// TODO(zhongbo): add back when transpose is supported
// const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols;
// Helps resolve bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_mem_rowwise_data = buff_size_aligned_out;
constexpr size_t out_mem_colwise_data = buff_size_aligned_out;
constexpr size_t out_mem_rowwise_scales = 0;
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) does not guarantee that the pointer is aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *out_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
fp4e2m1x2 *out_t_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem + out_mem_rowwise_data);
nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
// TODO (zhongbo): finish this
float *amax_rowwise_ptr = nullptr;
float *amax_colwise_ptr = nullptr;
nvfp4_scale_t *split_rowwise_scale_ptr = nullptr;
// assume the amax is fixed for the current 128x128 chunk (requires padding to a multiple of 128 rows)
bool need_update_tensor_id = true;
int tensor_id = GetTensorIdAndBoundary(&kernel_args, block_offset_Y, block_offset_Y + CHUNK_DIM_Y,
&need_update_tensor_id);
size_t split_start = kernel_args.split_sections_range[tensor_id];
size_t split_end = kernel_args.split_sections_range[tensor_id + 1];
amax_rowwise_ptr = reinterpret_cast<float *>(kernel_args.rowwise_amax_list[tensor_id]);
split_rowwise_scale_ptr =
reinterpret_cast<nvfp4_scale_t *>(kernel_args.output_rowwise_scale_inv_list[tensor_id]);
float S_enc_rowwise = 1.0f;
float S_dec_rowwise = 1.0f;
UpdateEncodeDecodeScaleFP32(amax_rowwise_ptr, &S_enc_rowwise, &S_dec_rowwise);
// TODO (zhongbo): colwise scaling disabled for now because of transpose
float S_enc_colwise = 1.0f;
float S_dec_colwise = 1.0f;
if (amax_colwise_ptr != nullptr) {
UpdateEncodeDecodeScaleFP32(amax_colwise_ptr, &S_enc_colwise, &S_dec_colwise);
} else {
S_enc_colwise = S_enc_rowwise;
S_dec_colwise = S_dec_rowwise;
}
float thread_amax = 0.0f;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
initialize_barriers<STAGES, THREADS_NUM>(mbar, is_master_thread);
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
#pragma unroll
for (size_t stage = 0; stage < STAGES; ++stage) {
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
const size_t buff_offset_in = buff * BUFF_IN_SIZE;
const size_t buff_offset_out = buff * BUFF_OUT_SIZE;
const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE;
// For stages 1 to STAGES - 1, the tensor id may need to be updated.
// Skip the update in the last CTA for stages that fall out of bounds.
if (need_update_tensor_id && stage > 0 && (block_offset_Y + stage_offset_Y < rows)) {
int new_tensor_id = GetTensorId(&kernel_args, block_offset_Y + stage_offset_Y);
if (new_tensor_id != tensor_id) {
tensor_id = new_tensor_id;
split_start = kernel_args.split_sections_range[tensor_id];
split_end = kernel_args.split_sections_range[tensor_id + 1];
amax_rowwise_ptr = reinterpret_cast<float *>(kernel_args.rowwise_amax_list[tensor_id]);
UpdateEncodeDecodeScaleFP32(amax_rowwise_ptr, &S_enc_rowwise, &S_dec_rowwise);
split_rowwise_scale_ptr =
reinterpret_cast<nvfp4_scale_t *>(kernel_args.output_rowwise_scale_inv_list[tensor_id]);
// TODO (zhongbo): colwise scaling disabled for now because of transpose
// Skip fetching colwise amax pointer and scaling factor updates
}
}
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_IN_SIZE;
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
float block_amax = 0.0f;
// COLWISE scaling
if constexpr (RETURN_TRANSPOSE) {
#pragma unroll
for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) {
const size_t in_thread_offset_Y = 0 + it * SCALE_DIM;
const size_t in_thread_offset_X = thread_offset_X_colwise;
const size_t out_t_thread_offset_Y = thread_offset_X_colwise;
const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET;
const size_t shmem_offset_base_colwise_in =
buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X;
const size_t shmem_offset_base_colwise_out_t =
buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X;
block_amax = 0.0f;
float in_compute_colwise[SCALE_DIM];
IType in_colwise_IType[SCALE_DIM];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType block_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i]));
}
block_amax = static_cast<float>(block_amax_f16);
} else {
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_colwise =
(row_base_colwise + stage_offset_Y + i >= rows);
const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
// Without an activation, out-of-bounds elements are zero (TMA zero-fill), so they don't affect the amax
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_colwise);
// Store scaling factors through SHMEM
const size_t scale_idx_sh =
tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it;
out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8;
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
// 3. Scale elements
fp4e2m1x4 regs[SCALE_DIM / 4];
#pragma unroll
for (int e = 0; e < SCALE_DIM / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_colwise_IType[4 * e]);
regs[e] = ptx::mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else {
const float2 in01 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e]);
const float2 in23 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e + 2]);
regs[e] = ptx::mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const int group = thread_lane / 16;
uint32_t val[2];
uint32_t *regs_4x = reinterpret_cast<uint32_t *>(regs);
// Helps reduce bank conflicts
switch (group) {
case 0:
val[0] = regs_4x[0];
val[1] = regs_4x[1];
break;
case 1:
val[0] = regs_4x[1];
val[1] = regs_4x[0];
break;
}
uint32_t *out_t_data_sh_as_uint32_t =
reinterpret_cast<uint32_t *>(&out_t_data_sh[shmem_offset_base_colwise_out_t]);
out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2;
out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2;
}
}
// ROWWISE scaling
{
const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) {
const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
const size_t shmem_offset_base_rowwise_in =
buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
const size_t shmem_offset_base_rowwise_out =
buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;
const size_t it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE;
block_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
// Since the TMA data-alignment requirement is 16B (i.e. cols % 8 == 0 for BF16 elements),
// a single check in the column direction is sufficient to ensure the entire wave is inside the boundaries
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e],
in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const size_t j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds =
(block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds =
(row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
// Without an activation, out-of-bounds elements are zero (TMA zero-fill), so they don't affect the amax
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_rowwise);
// Check boundaries
const size_t scales_offset_Y =
scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE;
const size_t scales_offset_X = scales_offset_X_rowwise;
const bool rowwise_scale_is_within_bounds_Y =
(stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE) < chunk_rows;
// TODO(zhongbo): depending on input padding multiple (whether 128 or 64), use either scale_ptr or split_rowwise_scale_ptr
// const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X;
// if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) {
// scales_ptr[scale_idx_global] = S_dec_b_fp8;
// }
// Map to local split coordinates
const size_t split_rows = split_end - split_start;
const size_t local_scale_row = scales_offset_Y - split_start;
// Local bounds: 0 <= local_scale_row < split_rows
const bool local_rowwise_scale_is_within_bounds_Y = local_scale_row < split_rows;
// Index inside this split’s scale buffer
const size_t scale_idx_local = local_scale_row * scale_stride + scales_offset_X;
if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y &&
local_rowwise_scale_is_within_bounds_Y) {
split_rowwise_scale_ptr[scale_idx_local] = S_dec_b_fp8;
}
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
IType2 in01;
IType2 in23;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_IType[w].data.elt[2 * e]);
out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else if constexpr (IS_CACHED_ACT_OP) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_cached[w].data.elt[4 * e]);
out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else {
const int j = w * PACK_SIZE + 4 * e;
const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]);
const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]);
out.data.elt[e] = ptx::mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
out.store_to(&out_data_sh[shmem_offset_rowwise]);
}
}
}
__builtin_assume(thread_amax >= 0);
thread_amax = fmaxf(thread_amax, block_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const size_t global_offset_X = block_offset_X;
// TODO(zhongbo): add back when transpose is supported
// const size_t global_offset_Y_t = block_offset_Y_t;
// const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y,
reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));
// TODO(zhongbo): add back when transpose is supported
// if constexpr (RETURN_TRANSPOSE) {
// ptx::cp_async_bulk_tensor_2d_shared_to_global(
// reinterpret_cast<const uint64_t *>(&tensor_map_output_t), global_offset_X_t,
// global_offset_Y_t, reinterpret_cast<uint64_t *>(&out_t_data_sh[buff_offset_out_t]));
// }
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
} // end of stages
// TODO(zhongbo): add back when transpose is supported
// Vectorized store scaling factors through SHMEM
// if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) {
// using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_Y>;
// const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y;
// ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(&out_colwise_scales_sh[scale_idx_sh]);
// const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t;
// const size_t count = // number of scales in Y dimension of this chunk
// (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM);
// nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global];
// constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t);
// if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast<uintptr_t>(dst) % vec_bytes == 0)) {
// // Fast path: vectorized store when destination is properly aligned
// scales_vec.store_to(dst);
// } else {
// // Safe path: element-wise store for tails or unaligned destinations
// scales_vec.store_to_elts(dst, 0, count);
// }
// }
destroy_barriers<STAGES>(mbar, is_master_thread);
#else
NVTE_DEVICE_ERROR("sm_100 or higher is required.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
#endif // FP4_TYPE_SUPPORTED
} // namespace group_quantize_transpose_kernel
template <bool use_2d_quantization>
void group_quantize_transpose(const Tensor &input, const Tensor *noop,
std::vector<Tensor *> &output_list, const size_t *split_sections,
size_t num_tensors, const QuantizationConfig *quant_config,
cudaStream_t stream) {
#if FP4_TYPE_SUPPORTED
using namespace group_quantize_transpose_kernel;
using namespace ptx;
bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
NVTE_CHECK(num_tensors == output_list.size(),
"Number of output tensors should match number of tensors.");
NVTE_CHECK(num_tensors <= kMaxTensorsPerKernel,
"Number of tensors should be less than or equal to ", kMaxTensorsPerKernel);
Tensor *output = nullptr;
// loop over the list to find the first non-empty tensor
for (size_t i = 0; i < num_tensors; ++i) {
if (output_list[i]->has_data()) {
output = output_list[i];
break;
}
}
NVTE_CHECK(output != nullptr, "No output tensor found.");
// also check that the output data pointer is not null
NVTE_CHECK(output->data.dptr != nullptr, "Output data pointer is null.");
// If the transposed output is allocated, return the transposed data. Otherwise, it is not
// necessary to return the transposed data.
bool return_transpose = output->has_columnwise_data();
// Forbid returning the transpose for now because group quantize does not support it yet
NVTE_CHECK(!return_transpose, "Return transpose is not supported for group quantize transpose.");
// output_list is contiguous in memory, so take the first tensor as the contiguous output
auto output_contiguous = output->data;
constexpr bool COMPUTE_ACTIVATIONS = false;
using ParamOP = Empty;
constexpr float (*OP)(float, const ParamOP &) = nullptr;
checkCuDriverContext(stream);
CheckNoopTensor(*noop, "cast_noop");
CheckInputTensor(input, "input");
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
NVTE_CHECK(rows % 32 == 0,
"Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA
NVTE_CHECK(cols % 32 == 0,
"Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA
// process the output list and produce the multi-tensor args for grouped kernel
MultiAmaxCastTransposeFusionArgs kernel_args;
kernel_args.num_tensors = 0;
kernel_args.split_sections_range[0] = 0;
for (size_t i = 0; i < num_tensors; ++i) {
if (split_sections[i] == 0) {
continue;
}
kernel_args.rowwise_amax_list[kernel_args.num_tensors] =
reinterpret_cast<void *>(output_list[i]->amax.dptr);
kernel_args.output_rowwise_scale_inv_list[kernel_args.num_tensors] =
reinterpret_cast<void *>(output_list[i]->scale_inv.dptr);
// kernel_args.split_sections[kernel_args.num_tensors] = split_sections[i];
kernel_args.split_sections_range[kernel_args.num_tensors + 1] =
kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i];
// check overflow
NVTE_CHECK(kernel_args.split_sections_range[kernel_args.num_tensors + 1] >= 0,
"split_sections_range overflow the int32_t");
kernel_args.num_tensors++;
}
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
const dim3 grid(blocks_X, blocks_Y);
const size_t block_size = THREADS_NUM;
// Note (zhongbo): for group quantize of [x1, x2, ..., xn],
// the rowwise scaling factor stride is shared between all tensors,
// while the colwise scaling factor stride differs per tensor because the transpose
// puts the token-dimension splits in the last dimension of the tensor
const size_t scale_stride = output->scale_inv.shape[1];
// const size_t scale_stride_transpose =
// return_transpose ? output->columnwise_scale_inv.shape[1] : 0;
nvfp4_scale_t *const scales_ptr = reinterpret_cast<nvfp4_scale_t *>(output->scale_inv.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr;
const size_t *rng_state = nullptr;
if (rng_state_tensor != nullptr) {
Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
"RNG state should contain 2 64-bit values.");
NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
rng_state = reinterpret_cast<const size_t *>(rng_state_te_tensor.data.dptr);
}
using IType = bf16;
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_output{};
// alignas(64) CUtensorMap tensor_map_output_transpose{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
sizeof(IType) * 8);
create_2D_tensor_map(tensor_map_output, output_contiguous, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
cols, 0, 4);
// if (return_transpose) {
// create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows,
// BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4);
// }
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_data_mem = buff_size_aligned_out;
constexpr size_t out_data_transpose_mem = buff_size_aligned_out;
constexpr size_t out_scales_transpose_mem = buff_size_scales;
constexpr size_t out_mem = out_data_mem + out_data_transpose_mem;
constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT;
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, USE_STOCHASTIC_ROUNDING,
TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, {
auto kernel =
group_quantize_transpose_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE>;
if constexpr (use_2d_quantization) {
NVTE_ERROR("2D quantization is not supported for group quantize transpose.");
}
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
kernel<<<grid, block_size, dshmem_size, stream>>>(tensor_map_input, tensor_map_output,
scales_ptr, noop_ptr, rows, cols,
scale_stride, rng_state, kernel_args);
NVTE_CHECK_CUDA(cudaGetLastError());
}););
#else
NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif // FP4_TYPE_SUPPORTED
}
} // namespace nvfp4
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GROUP_QUANTIZE_TRANSPOSE_NVFP4_CUH_
......@@ -394,6 +394,7 @@ struct QuantizationConfig {
NVTETensor rng_state = nullptr;
bool nvfp4_2d_quantization = false;
bool stochastic_rounding = false;
bool use_fast_math = false;
static constexpr size_t attr_sizes[] = {
sizeof(bool), // force_pow_2_scales
......@@ -402,7 +403,8 @@ struct QuantizationConfig {
sizeof(Float8BlockScaleTensorFormat), // float8_block_scale_tensor_format
sizeof(NVTETensor), // rng_seed and offset
sizeof(bool), // nvfp4_2d_quantization
sizeof(bool) // stochastic_rounding
sizeof(bool), // stochastic_rounding
sizeof(bool) // use_fast_math
};
};
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_CUSTOMIZED_PIPELINE_CUH_
#define TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_CUSTOMIZED_PIPELINE_CUH_
#include "cutlass/pipeline/sm100_pipeline.hpp"
namespace cutlass {
using namespace cute;
namespace detail {
// Producer-consumer pipeline implementation for a UMMA producer.
// In this case, UMMA barrier arrivals are used by producer_commit.
// Use case: accumulator generation as the result of MMA instructions.
template <int Stages_, class ClusterShape = Shape<int, int, _1>,
class AtomThrShape_MNK_ = Shape<_1, _1, _1> >
class CustomizedPipelineTmaUmmaAsync {
public:
static constexpr uint32_t Stages = Stages_;
using AtomThrShape_MNK = AtomThrShape_MNK_;
private:
using Impl = PipelineTmaAsync<Stages>;
public:
using FullBarrier = typename Impl::FullBarrier;
using EmptyBarrier = typename Impl::EmptyBarrier;
using ProducerBarrierType = typename Impl::ProducerBarrierType;
using ConsumerBarrierType = typename Impl::ConsumerBarrierType;
using PipelineState = typename Impl::PipelineState;
using SharedStorage = typename Impl::SharedStorage;
using ThreadCategory = typename Impl::ThreadCategory;
using Params = typename Impl::Params;
using McastDirection = McastDirection;
// Helper function to initialize barriers
static CUTLASS_DEVICE void init_barriers(SharedStorage& storage, Params params,
ClusterShape cluster_shape) {
int warp_idx = canonical_warp_idx_sync();
if (warp_idx == params.initializing_warp) {
// Barrier FULL and EMPTY init
constexpr int producer_arv_cnt = 1;
auto atom_thr_shape = AtomThrShape_MNK{};
uint32_t multicast_consumer_arrival_count = params.num_consumers; // If cluster_size is 1
if (cute::size(cluster_shape) > 1) {
multicast_consumer_arrival_count =
((cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) +
(cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1) *
params.num_consumers;
}
CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 &&
"Multicast consumer arrival count must be non-zero");
CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero");
cutlass::arch::detail::initialize_barrier_array_pair_aligned<
decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt,
multicast_consumer_arrival_count);
}
cutlass::arch::fence_barrier_init();
}
CUTLASS_DEVICE
void init_masks(ClusterShape cluster_shape,
dim3 block_id_in_cluster = cute::block_id_in_cluster()) {
// Calculate consumer mask
if (params_.role == ThreadCategory::Consumer) {
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRowCol>(
cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
}
}
CUTLASS_DEVICE
void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) {
// Calculate consumer mask
dim3 block_id_in_cluster = cute::block_id_in_cluster();
if (mcast_direction == McastDirection::kRow) {
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRow>(
cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
} else {
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kCol>(
cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
}
}
// Constructor by default initializes barriers and calculates masks.
// These operations can be explicitly deferred by specifying InitBarriers and InitMasks.
// If deferred, user code needs to guarantee init_masks and/or init_barriers is/are called.
template <typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>
CUTLASS_DEVICE CustomizedPipelineTmaUmmaAsync(SharedStorage& storage, Params params,
ClusterShape cluster_shape, InitBarriers = {},
InitMasks = {})
: impl_(storage, params, cluster_shape, cute::false_type{}, InitMasks{}),
params_(params),
empty_barrier_ptr_(&storage.empty_barrier_[0]),
full_barrier_ptr_(&storage.full_barrier_[0]) {
static_assert(cute::is_same_v<InitBarriers, cute::true_type> ||
cute::is_same_v<InitBarriers, cute::false_type>);
if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {
init_barriers(storage, params_, cluster_shape);
}
static_assert(cute::is_same_v<InitMasks, cute::true_type> ||
cute::is_same_v<InitMasks, cute::false_type>);
if constexpr (cute::is_same_v<InitMasks, cute::true_type>) {
init_masks(cluster_shape);
}
}
////////////////////
// Producer APIs
////////////////////
// Four member functions are always used in pairs:
//
// * producer_try_acquire and producer_acquire, and
// * consumer_try_wait and consumer_wait.
//
// The two functions with "try" in their names are called "try" functions,
// and the other two are conceptually "finalize" functions.
// The "try" function in each pair starts the process of waiting on the barrier to flip.
// It opportunistically waits for an implementation-dependent timeout.
// Whether or not the barrier has flipped yet, the try function will return a token.
// If the token indicates that the barrier has not flipped,
// then the token must be passed into the corresponding "finalize" function.
// The finalize function will then block until the barrier has flipped.
// If the token indicates that the barrier _has_ flipped,
// then it is still correct to pass it into the finalize function.
// The finalize function will return immediately in that case.
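// A typical round trip with these pairs looks like the following sketch (assuming one producer
// and one or more consumers, each advancing its own PipelineState):
//
//   // Producer side
//   auto ptok = pipeline.producer_try_acquire(p_state);
//   pipeline.producer_acquire(p_state, ptok);    // blocks until the stage is empty
//   /* issue the TMA / UMMA work for this stage */
//   ++p_state;
//
//   // Consumer side
//   auto ctok = pipeline.consumer_try_wait(c_state);
//   pipeline.consumer_wait(c_state, ctok);       // blocks until the stage is full
//   /* consume the stage */
//   pipeline.umma_consumer_release(c_state);     // or consumer_release(c_state)
//   ++c_state;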
CUTLASS_DEVICE
ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) {
return impl_.producer_try_acquire(state, skip_wait);
}
CUTLASS_DEVICE
void producer_acquire(PipelineState state,
ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
impl_.producer_acquire(state, barrier_token);
}
CUTLASS_DEVICE
void producer_expect_transaction(PipelineState state, uint32_t transaction_bytes) {
impl_.producer_expect_transaction(state, transaction_bytes);
}
// NOP for TMA based mainloop
CUTLASS_DEVICE
void producer_commit(PipelineState state, uint32_t bytes) { impl_.producer_commit(state, bytes); }
// Prevents early exit of producer blocks in Cluster.
// This should be called once before kernel exits.
CUTLASS_DEVICE
void producer_tail(PipelineState state) { impl_.producer_tail(state); }
CUTLASS_DEVICE
ProducerBarrierType* producer_get_barrier(PipelineState state) {
return impl_.producer_get_barrier(state);
}
////////////////////
// Consumer APIs
////////////////////
CUTLASS_DEVICE
ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) {
return impl_.consumer_try_wait(state, skip_wait);
}
CUTLASS_DEVICE
void consumer_wait(PipelineState state,
ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) {
impl_.consumer_wait(state, barrier_token);
}
CUTLASS_DEVICE
void umma_consumer_release(PipelineState state) { umma_consumer_release(state.index(), false); }
CUTLASS_DEVICE
void consumer_release(PipelineState state) { impl_.consumer_release(state); }
private:
Impl impl_;
Params params_;
EmptyBarrier* empty_barrier_ptr_;
FullBarrier* full_barrier_ptr_;
uint16_t block_id_mask_ = 0;
static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1;
// Consumer signalling Producer of completion
// Ensures all blocks in the Same Row and Column get notified.
CUTLASS_DEVICE
void umma_consumer_release(uint32_t stage, uint32_t skip) {
detail::pipeline_check_is_consumer(params_.role);
uint64_t* smem_ptr = reinterpret_cast<uint64_t*>(&empty_barrier_ptr_[stage]);
// {$nv-release-never begin}
// TODO: Needs to be updated once Blackwell specialized pipeline is implemented.
// XMMA style bar_peek will be tested. We will need to revisit skip interface and
// what skip means when we have bar_peek functionality.
// A separate MR will implement MMA_2x1SM specialized pipeline.
// {$nv-release-never end}
if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1
if (!skip) {
cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_);
}
} else {
if (!skip) {
if constexpr (cute::is_static_v<ClusterShape> && size(ClusterShape{}) == 1) {
cutlass::arch::umma_arrive(smem_ptr);
} else {
cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_);
}
}
}
}
};
} // namespace detail
} // namespace cutlass
#endif // TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_CUSTOMIZED_PIPELINE_CUH_
@@ -459,7 +459,7 @@ void group_hadamard_transform_amax(const Tensor& input_, std::vector<Tensor*>& o
}
// Zero out multiple amaxes if needed
// Currently don't support multi-launch when num_tensors is larger than kMaxTensorsPerKernel
// Use one thread per tensor: 1 block with kMaxTensorsPerKernel threads per block
dim3 block_setup_amax(kMaxTensorsPerKernel);
dim3 grid_setup_amax(1);
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <cutlass/arch/barrier.h>
#include <transformer_engine/hadamard_transform.h>
#include <cuda/barrier>
#include <cute/algorithm/gemm.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#include "common/common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/curanddx.hpp"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "cutlass/arch/barrier.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/builders/sm100_common.inl"
#include "cutlass/numeric_conversion.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/print_error.hpp"
namespace transformer_engine {
namespace detail {
namespace {
using namespace cute;
// Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor
using cute::Tensor;
using Stride2D = cute::Stride<int, cute::Int<1>>;
constexpr int kMaxTensorsPerKernel = 64;  // Kernel args must be < 4 KB; increase this limit if needed
struct MultiAmaxHadamardCastFusionArgs {
// (output) Amax buffer for the pre-RHT amax of each tensor
void *global_amax_list[kMaxTensorsPerKernel];
// output C pointers for each tensor
void *output_colwise_list[kMaxTensorsPerKernel];
// output scale inverse pointers for each tensor
void *output_colwise_scale_inv_list[kMaxTensorsPerKernel];
// split sections of each tensor of input
int split_sections[kMaxTensorsPerKernel];
// Prefix sum (with leading zero) of split_sections of each tensor of input
int split_sections_range[kMaxTensorsPerKernel + 1];
// stride 2D struct for CUTE
Stride2D output_stride2d_list[kMaxTensorsPerKernel];
// Number of tensors (splits) being processed by kernel
int num_tensors;
};
__device__ __forceinline__ float *GetGlobalAmaxPtrByTensorId(
MultiAmaxHadamardCastFusionArgs *kernel_args_ptr, int tensor_id) {
// directly returns the global amax pointer by tensor id
if (tensor_id < 0 || tensor_id >= kernel_args_ptr->num_tensors) {
return nullptr;
}
return reinterpret_cast<float *>(kernel_args_ptr->global_amax_list[tensor_id]);
}
__device__ __forceinline__ int GetTensorId(MultiAmaxHadamardCastFusionArgs *kernel_args_ptr,
int offset) {
// Check the kernel args and get the corresponding id
const int num_tensors = kernel_args_ptr->num_tensors;
if (offset >= kernel_args_ptr->split_sections_range[num_tensors]) {
return num_tensors - 1;
}
int tensor_id = 0;
while (kernel_args_ptr->split_sections_range[tensor_id + 1] <= offset) {
++tensor_id;
}
return tensor_id;
}
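// Illustrative example (hypothetical values): split_sections = {128, 64, 192}
// gives split_sections_range = {0, 128, 192, 384}; an offset of 150 then maps
// to tensor_id 1, since 128 <= 150 < 192.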
// calculate the global encode scale factor for a given global amax.
__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) {
constexpr float kFP8E4M3Max = 448.0f;
constexpr float kFP4E2M1Max = 6.0f;
// If scale is infinity, return max value of float32
float global_encode_scale = cutlass::minimum_with_nan_propagation<float>{}(
kFP8E4M3Max * kFP4E2M1Max / global_amax, cutlass::platform::numeric_limits<float>::max());
// If global amax is 0 or infinity, return 1
return (global_amax == 0.f || global_encode_scale == 0.f) ? 1.f : global_encode_scale;
}
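// Illustrative example: global_amax = 672 gives 448 * 6 / 672 = 4.0, while
// global_amax = 0 (or +inf, which drives the scale to 0) falls back to 1.0.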
template <class ElementA, class ElementB, class ASmemLayout, class BSmemLayout>
struct SharedStorage {
static constexpr int AccumulatorPipelineStageCount = 16;
using AtomThrShapeMNK = cute::Shape<_1, _1, _1>;
using AccumulatorPipeline =
cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / 4, AtomThrShapeMNK>;
using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage;
static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{});
using MainloopPipeline =
cutlass::PipelineTmaUmmaAsync<MainloopPipelineStageCount, Shape<_1, _1, _1>, AtomThrShapeMNK>;
using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage;
alignas(16) AccumulatorPipelineStorage accumulator;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) cute::uint64_t tma_barrier[1];
uint32_t tmem_base_ptr;
struct TensorStorage : cute::aligned_struct<128, _1> {
cute::array_aligned<ElementA, cute::cosize_v<ASmemLayout>> smem_A;
cute::array_aligned<ElementB, cute::cosize_v<BSmemLayout>> smem_B;
} tensors;
};
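// Converts 8 FP32 values to 8 packed FP4 (E2M1) values with stochastic rounding,
// consuming two 32-bit random words via the cvt.rs PTX instruction (requires
// architecture-specific compilation, e.g. sm_100a).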
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 8> StochasticNumericConverterBase(
cutlass::Array<float, 8> const &input, cutlass::Array<uint32_t, 2> const &rbits) {
using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>;
result_type output;
constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
if constexpr (has_rs) {
auto output_ptr = reinterpret_cast<uint16_t *>(&output);
asm volatile(
"{\n"
"cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n"
"cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n"
"}"
: "=h"(output_ptr[0]), "=h"(output_ptr[1])
: "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]), "f"(input[5]),
"f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1]));
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return output;
}
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 16> StochasticNumericConverter(
cutlass::Array<float, 16> const &input, cutlass::Array<uint32_t, 4> const *rbits) {
using result_type = cutlass::Array<cutlass::float_e2m1_t, 16>;
result_type output;
cutlass::Array<cutlass::float_e2m1_t, 8> *result_ptr =
reinterpret_cast<cutlass::Array<cutlass::float_e2m1_t, 8> *>(&output);
cutlass::Array<float, 8> const *source_ptr =
reinterpret_cast<cutlass::Array<float, 8> const *>(&input);
cutlass::Array<uint32_t, 2> const *rbits_ptr =
reinterpret_cast<cutlass::Array<uint32_t, 2> const *>(rbits);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 2; i++) {
result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]);
}
return output;
}
template <class MShape, class NShape, class KShape, class ClusterTileShape, class TA, class AStride,
class ASmemLayout, class TmaLoadA, class TB, class BStride, class BSmemLayout,
class TmaLoadB, class TC, class CStride, class CSmemLayout, class TSFC, class TiledMMA,
bool kEnableStochasticRounding = false, bool kUseFastMath = false>
__global__ static void group_rht_gemm_device(
MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, TA const *A, AStride dA,
ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, TB const *B, BStride dB,
BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, CSmemLayout, TiledMMA mma,
MultiAmaxHadamardCastFusionArgs kernel_args, const size_t *rng_state) {
using namespace cute;
using X = Underscore;
// static constexpr bool kApplyStochasticRounding = true;
using ElementAccumulator = float;
static constexpr int K_PIPE_MAX = size<3>(ASmemLayout{});
using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>;
static constexpr uint32_t kTmaTransactionBytes = cutlass::bits_to_bytes(
size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v<TA>);
static constexpr int kTmaRhtTensorTransactionBytes =
cutlass::bits_to_bytes(16 * 16 * cute::sizeof_bits_v<TB>);
static constexpr int AccumulatorPipelineStageCount = 16;
static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{});
using MainloopPipeline =
cutlass::PipelineTmaUmmaAsync<MainloopPipelineStageCount, Shape<_1, _1, _1>, AtomThrShapeMNK>;
using MainloopPipelineState = typename MainloopPipeline::PipelineState;
using TmemAllocator = cute::TMEM::Allocator1Sm;
static constexpr int VectorSize = 16;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
// Preconditions
CUTE_STATIC_ASSERT(is_static<ASmemLayout>::value);
CUTE_STATIC_ASSERT(is_static<BSmemLayout>::value);
CUTE_STATIC_ASSERT(is_static<CSmemLayout>::value);
// Represent the full tensors
Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, N));
Tensor mB = tma_load_b.get_tma_tensor(make_shape(16, 16));
using TensorC = decltype(make_tensor(subbyte_iterator<TC>(recast_ptr<TC>(nullptr)), // engine
make_shape(int{}, int{}), // (M, N_i)
Stride2D{} // stride (dM, dN)
));
using TensorSFC = decltype(make_tensor(
make_gmem_ptr(recast_ptr<TSFC>(nullptr)),
make_layout(make_shape(int{}, // M
make_shape(make_shape(Int<16>{}, _4{}), // (16, 4)
int{}) // n_tiles = split / 64
),
make_stride(int{}, // dM = (split / 16)
make_stride(make_stride(_0{}, _1{}), // inner (16,4) layout
_4{}) // tiles stride
))));
auto cluster_shape = Shape<_1, _1, _1>{};
// Get the appropriate blocks for this Cluster
dim3 cluster_coord_in_grid = cluster_id_in_grid();
// Total number of k-tiles
const int K_TILE_MAX = min(N, K) / 64;
uint32_t tiles_in_m = (M + size<0>(cluster_tile) - 1) / size<0>(cluster_tile);
uint32_t tiles_in_n = (N + 64 - 1) / 64;
uint32_t linear_tile_idx = blockIdx.x;
uint32_t tile_idx_m = linear_tile_idx % tiles_in_m;
uint32_t tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
auto mainloop_tiler = Shape<_128, _16, _64>{};
auto epilogue_tiler = Shape<_128, _64, _64>{};
Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor gB_nk =
local_tile(mB, cluster_tile, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
// Tensor gC_mn = local_tile(mC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N)
using TensorGC = decltype(local_tile(std::declval<TensorC>(), decltype(epilogue_tiler){},
make_coord(_, _, _), Step<_1, _1, X>{}));
using TensorGSFC = decltype(local_tile(std::declval<TensorSFC>(), decltype(epilogue_tiler){},
make_coord(_, _, _), Step<_1, _1, X>{}));
// Allocate SMEM
extern __shared__ char shared_memory[];
using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
SharedStorage &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);
Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()),
sAlayout); // (MMA,MMA_M,MMA_N,PIPE)
Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()),
sBlayout); // (MMA,MMA_N,MMA_K,PIPE)
//
// MMA: Define C accumulators and A/B partitioning
//
int block_rank_in_cluster = cute::block_rank_in_cluster();
ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx
Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k)
auto mma_epilogue = make_tiled_mma(
SM100_MMA_F16BF16_SS<TA, TB, ElementAccumulator, 128, 64, UMMA::Major::MN, UMMA::Major::MN>{},
Layout<Shape<_1, _1>>{});
ThrMMA thr_mma_epilogue = mma_epilogue.get_slice(block_rank_in_cluster);
using TiledMmaEpilogue = decltype(mma_epilogue);
Tensor tCgA = thr_mma.partition_A(gA_mk);
// Allocate "fragments" -- these are actually umma smem descriptors
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE)
auto acc_shape_mma = partition_shape_C(TiledMMA{}, take<0, 2>(ClusterTileShape{}));
auto acc_shape_epilogue = partition_shape_C(TiledMmaEpilogue{}, take<0, 2>(epilogue_tiler));
auto bulk_tmem_mma =
TiledMMA::make_fragment_C(append(acc_shape_mma, Int<AccumulatorPipelineStageCount>{}));
auto bulk_tmem_epilogue = TiledMmaEpilogue::make_fragment_C(
append(acc_shape_epilogue, Int<AccumulatorPipelineStageCount / 4>{}));
TmemAllocator tmem_allocator{};
cutlass::arch::NamedBarrier tmem_allocation_result_barrier(
32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier);
Layout cta_layout_mnk = make_layout(cluster_shape);
Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{}));
auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster);
auto [tAgA, tAsA] =
tma_partition(tma_load_a, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)),
group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA));
auto [tBgB, tBsB] =
tma_partition(tma_load_b, get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)),
group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB));
uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk);
int warp_idx = cutlass::canonical_warp_idx_sync();
bool is_mma_warp = (warp_idx == 0);
bool is_dma_warp = (warp_idx == 1);
bool is_epilogue_warp = (warp_idx >= 4 && warp_idx <= 7);
// if (is_epilogue_warp && elect_one_sync()) {
// // prefetch to make the global amax in cache
// for (size_t i = 0; i < kernel_args.num_tensors; ++i) {
// cute::prefetch(raw_pointer_cast(kernel_args.global_amax_list[i]));
// }
// }
typename MainloopPipeline::Params mainloop_pipeline_params;
if (is_dma_warp) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
}
if (is_mma_warp) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp;
mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes;
mainloop_pipeline_params.initializing_warp = 0;
MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params,
cluster_shape, cute::true_type{}, // Perform barrier init
cute::true_type{}); // Delay mask calculation
MainloopPipelineState mainloop_pipe_consumer_state;
MainloopPipelineState mainloop_pipe_producer_state =
cutlass::make_producer_start_state<MainloopPipeline>();
using AccumulatorPipeline =
cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / 4, AtomThrShapeMNK>;
using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState;
AccumulatorPipelineState accumulator_pipe_consumer_state;
AccumulatorPipelineState accumulator_pipe_producer_state =
cutlass::make_producer_start_state<AccumulatorPipeline>();
typename AccumulatorPipeline::Params accumulator_pipeline_params;
if (is_mma_warp) {
accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer;
}
if (is_epilogue_warp) {
accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer;
}
// Only one producer thread arrives on this barrier.
accumulator_pipeline_params.producer_arv_count = 1;
accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * 128;
accumulator_pipeline_params.initializing_warp = 1;
AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, accumulator_pipeline_params,
cluster_shape,
cute::true_type{}, // Perform barrier init
cute::true_type{}); // Delay mask calculation
if (warp_idx == 2 && elect_one_sync()) {
cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1);
}
__syncthreads();
using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x;
if (is_dma_warp) {
if (elect_one_sync()) {
cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0],
kTmaRhtTensorTransactionBytes);
copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0),
tBsB(_, 0));
}
do {
bool is_first_wave = linear_tile_idx == blockIdx.x;
uint32_t skip_wait = is_first_wave;
auto tAgA_mk = tAgA(_, tile_idx_m, _);
int k_tile = 0;
auto barrier_token =
mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait);
CUTE_NO_UNROLL
while (k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n) {
int k_tile_idx_n = tile_idx_n + k_tile;
++k_tile;
skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount);
mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token);
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType *tma_barrier =
mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state);
int write_stage = mainloop_pipe_producer_state.index();
++mainloop_pipe_producer_state;
barrier_token =
mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait);
if (cute::elect_one_sync()) {
copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n),
tAsA(_, write_stage));
}
}
linear_tile_idx += gridDim.x;
tile_idx_m = linear_tile_idx % tiles_in_m;
tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
} while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
mainloop_pipeline.producer_tail(mainloop_pipe_producer_state);
} else if (is_mma_warp) {
mma.accumulate_ = UMMA::ScaleOut::Zero;
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
tmem_allocation_result_barrier.arrive();
uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
bulk_tmem_mma.data() = tmem_base_ptr;
cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/);
do {
uint32_t skip_wait = K_TILE_MAX <= 0;
auto barrier_token =
mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
CUTE_NO_UNROLL
for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n;) {
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
int read_stage = mainloop_pipe_consumer_state.index();
auto tCrA_mk = tCrA(_, _, _, read_stage);
auto tCrB_nk = tCrB(_, _, 0, 0);
CUTE_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA) / 4; ++k_block) {
accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
CUTE_UNROLL
for (int i = 0; i < 4; i++) {
auto accumulators =
bulk_tmem_mma(_, _, _, accumulator_pipe_producer_state.index() * 4 + i);
gemm(mma, tCrA_mk(_, _, k_block * 4 + i), tCrB_nk, accumulators);
}
accumulator_pipeline.producer_commit(accumulator_pipe_producer_state);
++accumulator_pipe_producer_state;
}
auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;
++mainloop_pipe_consumer_state;
++k_tile;
skip_wait = k_tile >= K_TILE_MAX;
barrier_token =
mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
}
linear_tile_idx += gridDim.x;
tile_idx_m = linear_tile_idx % tiles_in_m;
tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
} while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
tmem_allocator.release_allocation_lock();
accumulator_pipeline.producer_tail(accumulator_pipe_producer_state);
tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns);
} else if (is_epilogue_warp) {
static constexpr int FragmentSize = 256 / sizeof_bits_v<TC>;
tmem_allocation_result_barrier.arrive_and_wait();
uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
bulk_tmem_epilogue.data() = tmem_base_ptr;
int thread_idx = threadIdx.x % 128;
auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{}));
auto tiled_r2g =
make_tiled_copy_D(Copy_Atom<SM100_STORE_256bit_CACHE_NOALLOCATION, TC>{}, tiled_t2r);
auto thr_t2r = tiled_t2r.get_slice(thread_idx);
auto thr_r2g = tiled_r2g.get_slice(thread_idx);
// NVFP4 non-E8 recipe constants and global scales
static constexpr float fp4_max = 6.0f;
static constexpr float fp4_max_inv = 1.0f / fp4_max;
// get global amax pointer
int tensor_id = GetTensorId(&kernel_args, tile_idx_n * 64);
float *global_amax_ptr = GetGlobalAmaxPtrByTensorId(&kernel_args, tensor_id);
TC *cur_output_colwise_ptr = reinterpret_cast<TC *>(kernel_args.output_colwise_list[tensor_id]);
TSFC *cur_output_colwise_scale_inv_ptr =
reinterpret_cast<TSFC *>(kernel_args.output_colwise_scale_inv_list[tensor_id]);
int cur_output_colwise_n = kernel_args.split_sections[tensor_id];
TensorC cur_mC =
cute::make_tensor(cute::subbyte_iterator<TC>(cur_output_colwise_ptr),
cute::make_shape(static_cast<int>(M), cur_output_colwise_n), // (M, N_i)
kernel_args.output_stride2d_list[tensor_id]);
auto cur_sfc_shape =
make_shape(M, make_shape(make_shape(Int<16>{}, _4{}), cur_output_colwise_n / 64));
auto cur_sfc_stride =
make_stride(cur_output_colwise_n / 16, make_stride(make_stride(_0{}, _1{}), _4{}));
TensorSFC cur_mSFC = cute::make_tensor(make_gmem_ptr(cur_output_colwise_scale_inv_ptr),
make_layout(cur_sfc_shape, cur_sfc_stride));
TensorGC cur_gC_mn =
local_tile(cur_mC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} // (BLK_M, BLK_N)
);
TensorGSFC cur_gSFC_mn = local_tile(
cur_mSFC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} // (BLK_M, BLK_N-like)
);
Tensor tCgC = thr_mma_epilogue.partition_C(cur_gC_mn);
float global_amax_val = *global_amax_ptr;
float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val);
// Scaling factor for fast math path
float global_encode_scale_multiplier = 1.0f;
if constexpr (kUseFastMath) {
global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
}
float global_decode_scale = 1.0f / global_encode_scale;
auto sfd_converter = cutlass::NumericConverter<TSFC, float>{};
do {
for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) {
// get the starting index of current k-tile in global tensor, to query the correct global amax
int cur_k_tile_global_elem_idx = (tile_idx_n + k_tile) * 64;
int new_tensor_id = GetTensorId(&kernel_args, cur_k_tile_global_elem_idx);
// float* new_global_amax_ptr = GetGlobalAmaxPtr(&kernel_args, cur_k_tile_global_elem_idx);
global_amax_ptr = GetGlobalAmaxPtrByTensorId(&kernel_args, new_tensor_id);
// update the scaling factors when it's no longer the same amax pointer
// TODO(zhongbo): the math operations are very expensive
// since the kernel is persistent, we can have a cache for all the possible scaling factors
if (tensor_id != new_tensor_id) {
global_amax_val = *global_amax_ptr;
global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val);
if constexpr (kUseFastMath) {
global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
}
global_decode_scale = 1.0f / global_encode_scale;
tensor_id = new_tensor_id;
// go through the CUTE operations to update the local tensors
cur_output_colwise_ptr =
reinterpret_cast<TC *>(kernel_args.output_colwise_list[tensor_id]);
cur_output_colwise_scale_inv_ptr =
reinterpret_cast<TSFC *>(kernel_args.output_colwise_scale_inv_list[tensor_id]);
cur_output_colwise_n = kernel_args.split_sections[tensor_id];
cur_mC = cute::make_tensor(
cute::subbyte_iterator<TC>(cur_output_colwise_ptr),
cute::make_shape(static_cast<int>(M), cur_output_colwise_n), // (M, N_i)
kernel_args.output_stride2d_list[tensor_id]);
cur_sfc_shape =
make_shape(M, make_shape(make_shape(Int<16>{}, _4{}), cur_output_colwise_n / 64));
cur_sfc_stride =
make_stride(cur_output_colwise_n / 16, make_stride(make_stride(_0{}, _1{}), _4{}));
cur_mSFC = cute::make_tensor(make_gmem_ptr(cur_output_colwise_scale_inv_ptr),
make_layout(cur_sfc_shape, cur_sfc_stride));
cur_gC_mn = local_tile(
cur_mC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{} // (BLK_M, BLK_N)
);
cur_gSFC_mn = local_tile(cur_mSFC, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{}
// (BLK_M, BLK_N-like)
);
tCgC = thr_mma_epilogue.partition_C(cur_gC_mn);
}
// tensor_id may have been updated to the new tensor id
int tensor_start_elem = kernel_args.split_sections_range[tensor_id];
int local_tile_idx_n = (cur_k_tile_global_elem_idx - tensor_start_elem) / 64;
Tensor tCgC_mn = tCgC(_, _, _, tile_idx_m, local_tile_idx_n);
Tensor tCgSFC_mn = cur_gSFC_mn(_, _, tile_idx_m, local_tile_idx_n);
accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state);
auto tCtC = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index());
Tensor tDtC = thr_t2r.partition_S(tCtC); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tDgC = thr_t2r.partition_D(tCgC_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tTR_rAcc =
make_tensor<ElementAccumulator>(shape(tDgC)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tDrC = make_tensor<TC>(shape(tDgC));
Tensor tTR_rAcc_frag =
recast<cutlass::Array<ElementAccumulator, FragmentSize>>(coalesce(tTR_rAcc));
Tensor tDrC_frag = recast<cutlass::Array<TC, FragmentSize>>(coalesce(tDrC));
Tensor src = thr_r2g.retile_S(tDrC);
Tensor dst = thr_r2g.retile_D(tDgC);
Tensor tCgSFC = make_tensor(
tCgSFC_mn.data(), make_layout(make_shape(shape(tCgSFC_mn), Int<1>{}, Int<1>{}),
make_stride(stride(tCgSFC_mn), Int<0>{}, Int<0>{})));
Tensor tDgSFC = filter(thr_t2r.partition_D(tCgSFC));
Tensor tDrSFC = make_tensor<TSFC>(shape(tDgSFC));
static constexpr int NumVecs = size(tDgC) / VectorSize;
Tensor tC_rRowSFD_frg = recast<cutlass::Array<TSFC, NumVecs>>(tDrSFC);
cutlass::maximum_absolute_value_reduction<cutlass::Array<ElementAccumulator, VectorSize>,
true>
amax_reduction;
cutlass::Array<ElementAccumulator, NumVecs> vec_maxs;
cutlass::Array<ElementAccumulator, NumVecs> pvscales;
// TMEM_LOAD
copy(tiled_t2r, tDtC, tTR_rAcc);
cutlass::arch::fence_view_async_tmem_load();
accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state);
++accumulator_pipe_consumer_state;
if constexpr (!kUseFastMath) {
// Downcast to BF16 for bit-wise compatibility with unfused
// kernels
auto convert_accum_to_bf16 =
cutlass::NumericArrayConverter<cutlass::bfloat16_t, ElementAccumulator,
FragmentSize>{};
auto convert_bf16_to_accum =
cutlass::NumericArrayConverter<ElementAccumulator, cutlass::bfloat16_t,
FragmentSize>{};
tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{})));
}
auto compute_frgs = reinterpret_cast<cutlass::Array<ElementAccumulator, VectorSize> *>(
tTR_rAcc_frag.data());
auto output_frgs = reinterpret_cast<cutlass::Array<TC, VectorSize> *>(tDrC_frag.data());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < NumVecs; v++) {
vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]);
}
if constexpr (kUseFastMath) {
// Fast math: multiply with precomputed reciprocal
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(
vec_maxs, global_encode_scale_multiplier);
} else {
// Accurate math: perform division
pvscales =
cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(vec_maxs, fp4_max);
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(
pvscales, global_encode_scale);
}
auto pvscales_cvted =
cutlass::NumericArrayConverter<TSFC, ElementAccumulator, NumVecs>{}(pvscales);
tC_rRowSFD_frg(_0{}) = pvscales_cvted;
auto qpvscale_ups = cutlass::NumericArrayConverter<ElementAccumulator, TSFC, NumVecs>{}(
tC_rRowSFD_frg(_0{}));
auto qpvscale_scaled = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(
qpvscale_ups, global_decode_scale);
cutlass::Array<ElementAccumulator, NumVecs> acc_scales;
if constexpr (kUseFastMath) {
// Fast math: compute approximate reciprocal
acc_scales =
cutlass::reciprocal_approximate_ftz<decltype(qpvscale_scaled)>{}(qpvscale_scaled);
} else {
// Accurate math: compute reciprocal with division
acc_scales =
cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(1.0, qpvscale_scaled);
}
// Initialize RNG for tile
const size_t rng_sequence = thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256;
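// Each (thread, k-tile, persistent tile) combination gets a distinct Philox
// subsequence, so stochastic-rounding bits are not reused within a launch.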
transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
rng.init(rng_seed, rng_sequence, rng_offset);
uint4 random_uint4 = uint4{0, 0, 0, 0};
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < NumVecs; v++) {
auto acc_scale = cutlass::minimum_with_nan_propagation<ElementAccumulator>{}(
acc_scales[v], cutlass::platform::numeric_limits<ElementAccumulator>::max());
// auto acc_scale = acc_scales[v];
if constexpr (kEnableStochasticRounding) {
random_uint4 = rng.generate4();
output_frgs[v] = StochasticNumericConverter(
cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
compute_frgs[v], acc_scale),
reinterpret_cast<cutlass::Array<uint32_t, 4> *>(&random_uint4));
} else {
output_frgs[v] = cutlass::NumericArrayConverter<TC, ElementAccumulator, VectorSize>{}(
cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
compute_frgs[v], acc_scale));
}
}
copy(tiled_r2g, src, dst);
// copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrC, tDgC);
copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFC, tDgSFC);
}
linear_tile_idx += gridDim.x;
tile_idx_m = linear_tile_idx % tiles_in_m;
tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
} while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
}
}
// this function computes RHT-GEMM for
// A: m x n: col-major
// B: 16 x 16: row-major
// C: m x n: row-major
// SFC: m x (n/16): row-major
template <typename TA, typename TB, typename TC, typename TSFC,
bool kEnableStochasticRounding = false, bool kUseFastMath = false>
void group_rht_gemm_ntt_w_sfc(int m, int n, TA const *A, TB const *B,
MultiAmaxHadamardCastFusionArgs *kernel_args_ptr,
const size_t *rng_state, uint32_t sm_count, cudaStream_t stream,
int k_tile_size = 2048) {
using namespace cute;
// Define shapes (dynamic)
auto M = static_cast<int>(m);
auto N = static_cast<int>(n);
// Define strides (mixed)
auto dA = make_stride(Int<1>{}, m); // (dM,dK)
auto dB = make_stride(Int<1>{}, 16); // (dN,dK)
for (size_t i = 0; i < kernel_args_ptr->num_tensors; ++i) {
kernel_args_ptr->output_stride2d_list[i] =
make_stride(kernel_args_ptr->split_sections[i], Int<1>{});
}
auto cga_shape = Shape<_1, _1, _1>{};
auto cga_tile_shape = Shape<_128, _16, _16>{};
auto cluster_tile_mainloop = Shape<_128, _16, _64>{};
// Construct the MMA
auto mma = make_tiled_mma(
SM100_MMA_F16BF16_SS<TA, TB, float, 128, 16, UMMA::Major::MN, UMMA::Major::MN>{},
Layout<Shape<_1, _1>>{});
// MMA in CGA Layout XXX: Need to generalize synchro? {$nv-release-never}
// Assert that the TiledMMA uses all CTAs in the CGA.
CUTE_STATIC_ASSERT_V(size(cga_shape) == size(mma));
CUTE_STATIC_ASSERT_V(evenly_divides(cga_tile_shape, tile_shape(mma)));
// Determine the A and B shapes
auto mma_shape_B =
partition_shape_B(mma, make_shape(size<1>(cga_tile_shape), size<2>(cga_tile_shape)));
using TiledMma = decltype(mma);
using AtomThrID = typename TiledMma::AtomThrID;
using SmemShape_M = decltype(shape_div(
shape<0>(cga_tile_shape),
shape_div(shape<0>(cga_tile_shape), size<0>(cga_tile_shape) / size(AtomThrID{}))));
using SmemShape_N = decltype(shape_div(
shape<1>(cga_tile_shape),
shape_div(shape<1>(cga_tile_shape), size<1>(cga_tile_shape) / size(AtomThrID{}))));
using SmemShape_K = decltype(cute::get<2>(cga_tile_shape));
using SmemLayoutAtomB =
decltype(cutlass::gemm::collective::detail::sm100_smem_selector<cute::UMMA::Major::MN, TB,
SmemShape_N, SmemShape_K>());
auto mma_shape_A = partition_shape_A(
mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop)));
using SmemShape_M_A =
decltype(shape_div(shape<0>(cluster_tile_mainloop),
shape_div(shape<0>(cluster_tile_mainloop),
size<0>(cluster_tile_mainloop) / size(AtomThrID{}))));
using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop));
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>());
// Define the smem layouts (static)
// Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory
constexpr int kBlackwellSmemSize = 232448; // 232KB in bytes
constexpr int kBytesPerStage =
cute::size(mma_shape_A) * sizeof(TA) + cute::size(mma_shape_B) * sizeof(TB);
constexpr int kReservedBytes = 256; // Reserve for barriers and other uses
constexpr int kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage;
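// Rough illustration (assuming BF16 A and B): a 128x64 A stage tile (~16 KiB)
// plus a 16x16 B tile (~0.5 KiB) per stage allows on the order of 13 stages.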
auto sP = Int<kMaxStages>{}; // SMEM pipelines
auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{},
append(mma_shape_A, sP)); // (MMA,MMA_M,MMA_K,PIPE)
auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{},
append(mma_shape_B, sP)); // (MMA,MMA_N,MMA_K,PIPE)
auto sC = Layout<_1>{}; // XXX Dummy
// Create GMEM tensors
Tensor tensorA = make_tensor(A, make_layout(make_shape(M, N), dA)); // (M,N)
Tensor tensorB = make_tensor(B, make_layout(make_shape(16, 16), dB)); // (16,16)
// Create the TiledCopy
auto tma_load_a =
make_tma_copy_A_sm100(SM90_TMA_LOAD{}, tensorA, sA(_, _, _, 0), cluster_tile_mainloop, mma);
auto tma_load_b =
make_tma_copy_B_sm100(SM90_TMA_LOAD{}, tensorB, sB(_, _, _, 0), cga_tile_shape, mma);
// Assert checks on tile sizes -- no predication
NVTE_CHECK(M % size<0>(cga_tile_shape) == 0, "Inner dimension must be divisible by ",
static_cast<size_t>(size<0>(cga_tile_shape)), " but got ", M, ".");
NVTE_CHECK(N % (4 * size<1>(cga_tile_shape)) == 0, "Outer dimension must be divisible by ",
4 * static_cast<size_t>(size<1>(cga_tile_shape)), " but got ", N, ".");
uint32_t tiles = size(ceil_div(M, get<0>(cga_tile_shape))) * size(ceil_div(N, k_tile_size));
tiles = (tiles < sm_count) ? tiles : sm_count;
dim3 dimBlock(256);
dim3 dimCluster(size<0>(cga_shape), size<1>(cga_shape), size<2>(cga_shape));
dim3 dimGrid(tiles, 1, 1);
int smem_size = sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>);
auto *kernel_ptr = &group_rht_gemm_device<
decltype(M), decltype(N), decltype(k_tile_size), decltype(cga_tile_shape), TA, decltype(dA),
decltype(sA), decltype(tma_load_a), TB, decltype(dB), decltype(sB), decltype(tma_load_b), TC,
Stride2D, decltype(sC), TSFC, decltype(mma), kEnableStochasticRounding, kUseFastMath>;
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
(*kernel_ptr)<<<dimGrid, dimBlock, smem_size, stream>>>(M, N, k_tile_size, cga_tile_shape, A, dA,
sA, tma_load_a, B, dB, sB, tma_load_b, sC,
mma, *kernel_args_ptr, rng_state);
NVTE_CHECK_CUDA(cudaGetLastError());
}
// This function wraps group_rht_gemm_ntt_w_sfc
// to transpose the input tensor A.
template <typename TA, typename TB, typename TC, typename TSFC,
bool kEnableStochasticRounding = false, bool kUseFastMath = false>
void group_rht_gemm_ttt_wrapper(int m, int n, TA const *A, TB const *B,
MultiAmaxHadamardCastFusionArgs *kernel_args_ptr,
const size_t *rng_state, uint32_t sm_count, cudaStream_t stream,
int k_tile_size = 1024) {
// in addition to transposing the input tensor A
// we also need to swap m and n to best
// utilize as many SMs as possible while keeping
// a relatively large contiguous dimension.
// for example, after swapping m, n for transpose purposes,
// the input / output tensor shapes for RHT-GEMM are:
// A: n x m: col-major
// B: 16 x 16: row-major
// C: n x m: row-major
// SFC: n x (m/16): row-major
group_rht_gemm_ntt_w_sfc<TA, TB, TC, TSFC, kEnableStochasticRounding, kUseFastMath>(
n, m, A, B, kernel_args_ptr, rng_state, sm_count, stream, k_tile_size);
}
} // namespace
} // namespace detail
void group_hadamard_transform_cast_fusion_columnwise(
const Tensor &input_, std::vector<Tensor *> &output_list, const size_t *split_sections,
size_t num_tensors, const Tensor &hadamard_matrix_, QuantizationConfig &quant_config,
cudaStream_t stream) {
NVTE_API_CALL(group_hadamard_transform_cast_fusion_columnwise);
using transformer_engine::detail::kMaxTensorsPerKernel;
using transformer_engine::detail::MultiAmaxHadamardCastFusionArgs;
NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
"Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
NVTE_CHECK(input_.dim() >= 2, "Input must be at least a 2D tensor.");
const SimpleTensor &input = input_.data;
NVTE_CHECK(output_list.size() == num_tensors,
"Number of output tensors should match number of tensors.");
NVTE_CHECK(num_tensors <= kMaxTensorsPerKernel,
"Number of tensors should be less than or equal to ", kMaxTensorsPerKernel);
// construct the multi-tensor args
MultiAmaxHadamardCastFusionArgs kernel_args;
kernel_args.num_tensors = 0;
kernel_args.split_sections_range[0] = 0;
for (size_t i = 0; i < num_tensors; ++i) {
NVTE_CHECK(split_sections[i] % 64 == 0, "component ", i,
" of split_sections should be a multiple of 64");
if (split_sections[i] == 0) {
continue;
}
kernel_args.global_amax_list[kernel_args.num_tensors] =
reinterpret_cast<void *>(output_list[i]->amax.dptr);
// TODO(zhongbo): should we change API assumption to use columnwise_data instead of data?
kernel_args.output_colwise_list[kernel_args.num_tensors] =
reinterpret_cast<void *>(output_list[i]->data.dptr);
kernel_args.output_colwise_scale_inv_list[kernel_args.num_tensors] =
reinterpret_cast<void *>(output_list[i]->scale_inv.dptr);
kernel_args.split_sections[kernel_args.num_tensors] = split_sections[i];
kernel_args.split_sections_range[kernel_args.num_tensors + 1] =
kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i];
kernel_args.num_tensors++;
}
// Stochastic rounding config
const bool use_stochastic_rounding = quant_config.stochastic_rounding;
const size_t *rng_state = nullptr;
if (quant_config.rng_state != nullptr) {
Tensor &rng_state_tensor = *convertNVTETensor(quant_config.rng_state);
NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64,
"RNG state should contain 2 64-bit values.");
NVTE_CHECK(rng_state_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape);
rng_state = reinterpret_cast<const size_t *>(rng_state_tensor.data.dptr);
}
// Template arguments
using TA = cute::bfloat16_t;
using TB = cute::bfloat16_t;
using TC = cutlass::float_e2m1_t;
using TSFC = cutlass::float_ue4m3_t;
checkCuDriverContext(stream);
// Check Hadamard matrix
constexpr int kHadamardDimension = 16;
NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16,
"Hadamard matrix must be BF16 tensor, but dtype is ",
to_string(hadamard_matrix_.dtype()), ".");
const SimpleTensor &hadamard_matrix = hadamard_matrix_.data;
NVTE_CHECK(
(hadamard_matrix_.shape() == std::vector<size_t>{kHadamardDimension, kHadamardDimension}),
"Hadamard matrix must have shape=",
std::vector<size_t>{kHadamardDimension, kHadamardDimension},
", but got shape=", hadamard_matrix_.shape(), ".");
const size_t hadamard_dimension = hadamard_matrix.shape[0];
const size_t ndim = input.shape.size();
const size_t n = input.shape[ndim - 1];
size_t m = 1;
for (size_t i = 0; i < ndim - 1; ++i) {
m *= input.shape[i];
}
auto sm_count = transformer_engine::cuda::sm_count();
NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension.");
NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension.");
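// Heuristic k-tile sizes tuned per (m, n) shape; unlisted shapes fall back to
// 1024, or 512 when either dimension is smaller than 1024.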
int k_tile_size = 1024;
if (m == 8192 && n == 5120) {
k_tile_size = 512;
} else if (m == 8192 && n == 10240) {
k_tile_size = 1024;
} else if (m == 8192 && n == 2560) {
k_tile_size = 1280;
} else if (m == 8192 && n == 11328) {
k_tile_size = 1024;
} else if (m == 8192 && n == 512) {
k_tile_size = 256;
} else if (m == 8192 && n == 3584) {
k_tile_size = 512;
} else if (m == 11328 && n == 8192) {
k_tile_size = 1024;
} else if (m == 5120 && n == 8192) {
k_tile_size = 512;
} else if (m == 10240 && n == 8192) {
k_tile_size = 1024;
} else if (m == 2560 && n == 8192) {
k_tile_size = 1280;
} else if (m == 512 && n == 8192) {
k_tile_size = 256;
} else if (m == 3584 && n == 8192) {
k_tile_size = 512;
} else if (m < 1024 || n < 1024) {
k_tile_size = 512;
}
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, kUseStochasticRounding,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
quant_config.use_fast_math, kUseFastMath,
detail::group_rht_gemm_ttt_wrapper<TA, TB, TC, TSFC, kUseStochasticRounding,
kUseFastMath>(
/*m=*/m, /*n=*/n, /*A=*/reinterpret_cast<TA const *>(input.dptr),
/*B=*/reinterpret_cast<TB const *>(hadamard_matrix.dptr),
/*kernel_args_ptr=*/&kernel_args, /*rng_state=*/rng_state, /*sm_count=*/sm_count,
/*stream=*/stream, /*k_tile_size=*/k_tile_size);););
}
} // namespace transformer_engine
void nvte_group_hadamard_transform_cast_fusion_columnwise(
const NVTETensor input, NVTETensor *outputs, const NVTETensor hadamard_matrix,
const size_t *split_sections, const size_t num_tensors,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_hadamard_transform_cast_fusion_columnwise);
using namespace transformer_engine;
NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0.");
Tensor *input_tensor = convertNVTETensorCheck(input);
std::vector<Tensor *> output_list(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
output_list[i] = convertNVTETensorCheck(outputs[i]);
}
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Call the multi-tensor Hadamard transform cast fusion implementation.
group_hadamard_transform_cast_fusion_columnwise(
*input_tensor, output_list, split_sections, num_tensors,
*convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <cutlass/arch/barrier.h>
#include <transformer_engine/hadamard_transform.h>
#include <cuda/barrier>
#include <cute/algorithm/gemm.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#include "common/common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/curanddx.hpp"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "customized_pipeline.cuh"
#include "cutlass/arch/barrier.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/fast_math.h"
#include "cutlass/float8.h"
#include "cutlass/float_subbyte.h"
#include "cutlass/gemm/collective/builders/sm100_common.inl"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/platform/platform.h"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/print_error.hpp"
namespace transformer_engine {
namespace detail {
namespace {
using namespace cute;
// Ensure Tensor refers to cute::Tensor, not transformer_engine::Tensor
using cute::Tensor;
constexpr int kMaxTensorsPerKernel = 64;
struct MultiAmaxHadamardCastFusionArgs {
// (output) Amax buffer for the amax of input A, per tensor
void *global_a_amax_list[kMaxTensorsPerKernel];
// (output) Amax buffer for the pre-RHT amax of each tensor
void *global_d_amax_list[kMaxTensorsPerKernel];
// output D pointers for each tensor
void *output_colwise_list[kMaxTensorsPerKernel];
// output SFD inverse pointers for each tensor
void *output_colwise_scale_inv_list[kMaxTensorsPerKernel];
// split sections of each tensor of input
int split_sections[kMaxTensorsPerKernel];
// Prefix sum (with leading zero) of split_sections of each tensor of input
int split_sections_range[kMaxTensorsPerKernel + 1];
// Number of tensors (splits) being processed by kernel
int num_tensors;
};
__device__ __forceinline__ int GetGroupIdx(MultiAmaxHadamardCastFusionArgs *kernel_args_ptr,
int offset) {
// Check the kernel args and get the corresponding id
const int num_tensors = kernel_args_ptr->num_tensors;
if (offset >= kernel_args_ptr->split_sections_range[num_tensors]) {
return num_tensors - 1;
}
int group_idx = 0;
while (kernel_args_ptr->split_sections_range[group_idx + 1] <= offset) {
++group_idx;
}
return group_idx;
}
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 8> StochasticNumericConverterBase(
cutlass::Array<float, 8> const &input, cutlass::Array<uint32_t, 2> const &rbits) {
using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>;
result_type output;
auto output_ptr = reinterpret_cast<uint16_t *>(&output);
constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
if constexpr (has_rs) {
asm volatile(
"{\n"
"cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n"
"cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n"
"}"
: "=h"(output_ptr[0]), "=h"(output_ptr[1])
: "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]), "f"(input[5]),
"f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1]));
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return output;
}
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 16> StochasticNumericConverter(
cutlass::Array<float, 16> const &input, cutlass::Array<uint32_t, 4> const &rbits) {
using result_type = cutlass::Array<cutlass::float_e2m1_t, 16>;
result_type output;
cutlass::Array<cutlass::float_e2m1_t, 8> *result_ptr =
reinterpret_cast<cutlass::Array<cutlass::float_e2m1_t, 8> *>(&output);
cutlass::Array<float, 8> const *source_ptr =
reinterpret_cast<cutlass::Array<float, 8> const *>(&input);
cutlass::Array<uint32_t, 2> const *rbits_ptr =
reinterpret_cast<cutlass::Array<uint32_t, 2> const *>(&rbits);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 2; i++) {
result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]);
}
return output;
}
template <class ElementA, class ElementB, class ASmemLayout, class BSmemLayout, class ClusterShape,
int AccumulatorPipelineStageCount_, int EpilogueUnrollFactor_,
int SchedulerPipelineStageCount_>
struct SharedStorage {
static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
static int constexpr EpilogueUnrollFactor = EpilogueUnrollFactor_;
using AtomThrShapeMNK = cute::Shape<_1, _1, _1>;
using AccumulatorPipeline =
cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / EpilogueUnrollFactor,
AtomThrShapeMNK>;
using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage;
static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{});
using MainloopPipeline =
cutlass::detail::CustomizedPipelineTmaUmmaAsync<MainloopPipelineStageCount, Shape<_1, _1, _1>,
AtomThrShapeMNK>;
using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage;
using SchedPipeline = cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount_, ClusterShape>;
using SchedPipelineStorage = typename SchedPipeline::SharedStorage;
using SchedThrottlePipeline = cutlass::PipelineAsync<SchedulerPipelineStageCount_>;
using SchedThrottlePipelineStorage = typename SchedThrottlePipeline::SharedStorage;
struct TensorStorage : cute::aligned_struct<128, _1> {
cute::array_aligned<ElementA, cute::cosize_v<ASmemLayout>> smem_A;
cute::array_aligned<ElementB, cute::cosize_v<BSmemLayout>> smem_B;
} tensors;
alignas(16) AccumulatorPipelineStorage accumulator;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) cute::uint64_t tma_barrier[1];
alignas(16) SchedPipelineStorage sched;
alignas(16) SchedThrottlePipelineStorage sched_throttle;
alignas(16) int32_t atomic_tile_id[SchedulerPipelineStageCount_];
alignas(16) float global_a_amax[kMaxTensorsPerKernel];
alignas(16) float global_d_amax[kMaxTensorsPerKernel];
uint32_t atomic_tile_counter[SchedulerPipelineStageCount_];
uint32_t tmem_base_ptr;
};
// Main RHT GEMM kernel entry -- highly templated for flexible architecture/config support
template <class MShape, class NShape, class KShape, class ClusterShape, class ClusterTileShape,
class TA, class AStride, class ASmemLayout, class TmaLoadA, class TB, class BStride,
class BSmemLayout, class TmaLoadB, class TD, class DStride, class DSmemLayout, class TSFD,
class TSFDLayout, class TQA, class QAStride, class TSFA, class TSFALayout, class TiledMMA,
int AccumulatorPipelineStageCount_, int SchedulerPipelineStageCount_,
bool kEnableStochasticRounding_ = false, bool kEnableRHTColQuant_ = true,
bool kEnableRowQuant_ = true, bool kEnableSwizzleSFOutput_ = false,
bool kUseFastMath_ = false>
__launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device(
MShape M, NShape packed_N, KShape K, ClusterShape cluster_shape, ClusterTileShape cluster_tile,
TA const *A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a,
TB const *B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b,
TQA *QA, QAStride dQA, TSFA *SFA, TSFALayout sfa_layout, MultiAmaxHadamardCastFusionArgs args,
uint32_t *tile_scheduler_workspace, TiledMMA mma, const size_t *rng_state) {
using namespace cute;
// Abort immediately if compilation is not supported
constexpr bool is_blackwell_arch = ARCH_BLACKWELL_FAMILY;
if constexpr (!is_blackwell_arch) {
NVTE_DEVICE_ERROR(
"group_row_col_rht_gemm_device is only supported on Blackwell "
"with architecture-specific compilation. "
"Try recompiling with sm_100a or similar.");
return;
}
static_assert(kEnableRHTColQuant_ || kEnableRowQuant_,
"group_row_col_rht_gemm_device must generate row-wise "
"and/or column-wise output.");
#if !defined(CUTLASS_ARCH_CLC_ENABLED)
CUTLASS_NOT_IMPLEMENTED();
return;
#endif
using X = Underscore;
// Accumulator data type for main computation
using ElementAccumulator = float;
static int constexpr K_PIPE_MAX = size<3>(ASmemLayout{});
using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>;
static uint32_t constexpr kTmaTransactionBytes = cutlass::bits_to_bytes(
size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v<TA>);
static constexpr bool kEnableStochasticRounding = kEnableStochasticRounding_;
static constexpr bool kEnableRHTColQuant = kEnableRHTColQuant_;
static constexpr bool kEnableRowQuant = kEnableRowQuant_;
static constexpr bool kEnableSwizzleSFOutput = kEnableSwizzleSFOutput_;
static constexpr bool kUseFastMath = kUseFastMath_;
// Constant for RHT tensor processing (tile size etc)
static int constexpr RhtTensorSize = 16;
// Transaction bytes for TMA transfer on RHT tensor blocks
static int constexpr kTmaRhtTensorTransactionBytes =
cutlass::bits_to_bytes(RhtTensorSize * RhtTensorSize * cute::sizeof_bits_v<TB>);
static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
static int constexpr SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
// Mainloop pipeline stage calculation, vectorization parameters for scaling factors
static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{});
static int constexpr SFVecSize = 16;
// Swizzle output layout for scaling factor arrays
using SwizzledSFALayoutAtom =
cutlass::detail::Sm1xxBlockScaledOutputConfig<SFVecSize, UMMA::Major::MN>::SfAtom;
using SwizzledSFDLayoutAtom =
cutlass::detail::Sm1xxBlockScaledOutputConfig<SFVecSize, UMMA::Major::K>::SfAtom;
// Mainloop pipeline types for TMA async execution and epilogue cluster scheduling
using MainloopPipeline =
cutlass::detail::CustomizedPipelineTmaUmmaAsync<MainloopPipelineStageCount, ClusterShape,
AtomThrShapeMNK>;
using MainloopPipelineState = typename MainloopPipeline::PipelineState;
using SchedPipeline = cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape>;
using SchedPipelineState = typename SchedPipeline::PipelineState;
using SchedThrottlePipeline = cutlass::PipelineAsync<SchedulerPipelineStageCount>;
using SchedThrottlePipelineState = typename SchedThrottlePipeline::PipelineState;
static_assert(ClusterShape{} == Shape<_1, _1, _1>{}, "ClusterShape must be Shape<_1,_1,_1>");
using TmemAllocator = cute::TMEM::Allocator1Sm;
static int constexpr VectorSize = RhtTensorSize;
// Compile-time safety: static shapes required for shared memory layouts
CUTE_STATIC_ASSERT(is_static<ASmemLayout>::value);
CUTE_STATIC_ASSERT(is_static<BSmemLayout>::value);
// CUTE_STATIC_ASSERT(is_static<DSmemLayout>::value);
auto cluster_size = size<0>(cluster_shape);
auto mainloop_tiler = Shape<_128, _16, _128>{};
auto epilogue_tiler = Shape<_128, _128, _128>{};
static int constexpr EpilogueUnrollFactor = size<2>(epilogue_tiler) / size<2>(cluster_tile);
// Get the appropriate blocks for this Cluster
dim3 cluster_coord_in_grid = cluster_id_in_grid();
// Total number of k-tiles
int const K_TILE_MAX = min(packed_N, K) / size<2>(epilogue_tiler);
struct TileScheduler {
uint32_t tiles_in_m = 0;
uint32_t tiles_in_n = 0;
uint32_t linear_idx = 0;
uint32_t next_linear_idx = 0;
uint32_t start_idx = 0;
uint32_t tile_m_idx = 0;
uint32_t tile_n_idx = 0;
int k_tile_max = 0;
uint32_t *atomic_tile_index_;
uint32_t *smem_tile_counter;
uint32_t atomic_offset;
cutlass::FastDivmodU64 divmod_tiles_in_m;
CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax,
uint32_t *atomic_tile_index, uint32_t *smem_tile_counter)
: tiles_in_m(tiles_m),
tiles_in_n(tiles_n),
linear_idx(blockIdx.x),
next_linear_idx(blockIdx.x),
start_idx(blockIdx.x),
k_tile_max(kmax),
atomic_tile_index_(atomic_tile_index),
smem_tile_counter(smem_tile_counter),
atomic_offset(gridDim.x),
divmod_tiles_in_m(uint64_t(tiles_m)) {
update_tile_idx();
}
CUTLASS_DEVICE void update_tile_idx() {
uint64_t q, r;
divmod_tiles_in_m(q, r, uint64_t(linear_idx));
tile_m_idx = static_cast<uint32_t>(r);
tile_n_idx = static_cast<uint32_t>(q) * uint32_t(k_tile_max);
}
CUTLASS_DEVICE uint32_t tile_m() const { return tile_m_idx; }
CUTLASS_DEVICE uint32_t tile_n_base() const { return tile_n_idx; }
CUTLASS_DEVICE uint32_t tiles_m() const { return tiles_in_m; }
CUTLASS_DEVICE uint32_t tiles_n() const { return tiles_in_n; }
CUTLASS_DEVICE bool is_valid() const {
return cute::elem_less(cute::make_coord(tile_m(), tile_n_base()),
cute::make_coord(tiles_in_m, tiles_in_n));
}
CUTLASS_DEVICE bool is_first_wave() const { return linear_idx == start_idx; }
CUTLASS_DEVICE uint32_t get_linear_tile_idx() const { return linear_idx; }
// Fetch a new tile_id using atomics.
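// Only the thread with pred == 1 (the elected leader) performs the global
// atomic add; all other threads return 0.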
CUTLASS_DEVICE uint32_t fetch_tile_id_counter(int pred) {
uint32_t tile_id_counter = 0;
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.eq.u32 p, %2, 1;\n\t"
"@p atom.global.add.u32 %0, [%1], 1; \n\t"
"}"
: "=r"(tile_id_counter)
: "l"(atomic_tile_index_), "r"(pred));
return tile_id_counter;
}
CUTLASS_DEVICE auto fetch_next_work(SchedPipeline &sched_pipeline,
SchedPipelineState sched_pipeline_consumer_state) {
sched_pipeline.consumer_wait(sched_pipeline_consumer_state);
next_linear_idx = smem_tile_counter[sched_pipeline_consumer_state.index()];
cutlass::arch::fence_view_async_shared();
sched_pipeline.consumer_release(sched_pipeline_consumer_state);
return;
}
CUTLASS_DEVICE auto advance_to_next_work(SchedPipeline &sched_pipeline,
SchedPipelineState sched_pipeline_producer_state) {
uint32_t mbarrier_addr = sched_pipeline.producer_get_barrier(sched_pipeline_producer_state);
// Wait for clcID buffer to become empty with a flipped phase
sched_pipeline.producer_acquire(sched_pipeline_producer_state);
auto is_leading_thread = cute::elect_one_sync();
uint32_t tile_id_counter = fetch_tile_id_counter(is_leading_thread) + atomic_offset;
uint32_t smem_addr =
cute::cast_smem_ptr_to_uint(&smem_tile_counter[sched_pipeline_producer_state.index()]);
if (is_leading_thread) {
cute::store_shared_remote(tile_id_counter, smem_addr, mbarrier_addr, 0);
}
++sched_pipeline_producer_state;
return sched_pipeline_producer_state;
}
CUTLASS_DEVICE auto update_work_tile_info() {
linear_idx = next_linear_idx;
update_tile_idx();
return;
}
};
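// Illustrative TileScheduler example: with tiles_in_m = 4 and k_tile_max = 8,
// linear_idx = 9 decomposes to tile_m_idx = 1 and tile_n_idx = 2 * 8 = 16.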
// Allocate and alias shared memory to the kernel's shared storage type
extern __shared__ char shared_memory[];
using SharedStorage =
SharedStorage<TA, TB, ASmemLayout, BSmemLayout, ClusterShape, AccumulatorPipelineStageCount,
EpilogueUnrollFactor, SchedulerPipelineStageCount>;
SharedStorage &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);
// Compute the number of tiles in M and N after tiling and assign scheduler
uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile))));
uint32_t tiles_in_n = uint32_t(
size(ceil_div(args.split_sections_range[args.num_tensors], size<2>(epilogue_tiler))));
TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, tile_scheduler_workspace,
shared_storage.atomic_tile_counter);
int block_rank_in_cluster = cute::block_rank_in_cluster();
// Shapes for accumulated tiles in mainloop and epilogue
auto acc_shape_mma = make_shape(take<0, 2>(mainloop_tiler), _1{}, _1{});
auto acc_shape_epilogue = make_shape(take<0, 2>(epilogue_tiler), _1{}, _1{});
// Shape of the accumulator fragment for the main loop pipeline, with pipeline stages appended
auto acc_mainloop_pipelined_shape = append(acc_shape_mma, Int<AccumulatorPipelineStageCount>{});
auto bulk_tmem_mma = TiledMMA::make_fragment_C(acc_mainloop_pipelined_shape);
// Number of threads assigned for various epilogue roles depending on quantization settings
static int constexpr NumEpilogueColQuantThreadCount = kEnableRHTColQuant ? 128 : 0;
static int constexpr NumEpilogueRowQuantThreadCount = kEnableRowQuant ? 256 : 0;
static int constexpr NumMmaThreadCount = kEnableRHTColQuant ? 32 : 0;
static int constexpr NumMmaIssueThreadCount = kEnableRHTColQuant ? 1 : 0;
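// Only a single elected thread in the MMA warp issues UMMA instructions, so the
// mainloop pipeline counts one MMA consumer even though the warp has 32 threads.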
static int constexpr NumSchedThreads = 32;
static int constexpr NumMainloopLoadThreads = 32;
static int constexpr NumEpilogueThreads =
NumEpilogueColQuantThreadCount + NumEpilogueRowQuantThreadCount;
TmemAllocator tmem_allocator{};
cutlass::arch::NamedBarrier tmem_allocation_result_barrier(
NumMmaThreadCount + NumEpilogueColQuantThreadCount,
cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier);
int warp_idx = cutlass::canonical_warp_idx_sync();
// warp assignment
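// With 512 threads per CTA: warp 0 issues MMAs, warp 1 performs TMA loads, warp 2 schedules,
// warp 3 is idle, warps 4-7 run the column-quant epilogue (128 threads), and
// warps 8-15 run the row-quant epilogue (256 threads).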
bool is_mma_warp = (warp_idx == 0);
bool is_dma_warp = (warp_idx == 1);
bool is_sched_warp = (warp_idx == 2);
bool is_epilogue_col_quant_warp = (warp_idx >= 4 && warp_idx <= 7);
bool is_epilogue_row_quant_warp = (warp_idx >= 8 && warp_idx <= 15);
typename MainloopPipeline::Params mainloop_pipeline_params;
if (is_dma_warp) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
}
if (is_mma_warp) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp;
mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes;
mainloop_pipeline_params.initializing_warp = 0;
mainloop_pipeline_params.num_consumers = NumEpilogueRowQuantThreadCount + NumMmaIssueThreadCount;
MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params,
cluster_shape, cute::true_type{}, // Perform barrier init
cute::true_type{}); // Delay mask calculation
MainloopPipelineState mainloop_pipe_consumer_state;
MainloopPipelineState mainloop_pipe_producer_state =
cutlass::make_producer_start_state<MainloopPipeline>();
using AccumulatorPipeline =
cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / EpilogueUnrollFactor,
AtomThrShapeMNK>;
using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState;
using AccumulatorPipelineInitBarriers = cute::bool_constant<kEnableRHTColQuant>;
AccumulatorPipelineState accumulator_pipe_consumer_state;
AccumulatorPipelineState accumulator_pipe_producer_state =
cutlass::make_producer_start_state<AccumulatorPipeline>();
typename AccumulatorPipeline::Params accumulator_pipeline_params;
if (is_mma_warp) {
accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer;
}
if (is_epilogue_col_quant_warp) {
accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer;
}
// Only one producer thread arrives on this barrier.
accumulator_pipeline_params.producer_arv_count = 1;
accumulator_pipeline_params.consumer_arv_count =
size(AtomThrShapeMNK{}) * NumEpilogueColQuantThreadCount;
accumulator_pipeline_params.initializing_warp = 1;
AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, accumulator_pipeline_params,
cluster_shape, AccumulatorPipelineInitBarriers{},
cute::true_type{}); // Delay mask calculation
typename SchedPipeline::Params sched_pipeline_params;
if (is_sched_warp) {
sched_pipeline_params.role = SchedPipeline::ThreadCategory::ProducerConsumer;
} else {
sched_pipeline_params.role = SchedPipeline::ThreadCategory::Consumer;
}
sched_pipeline_params.producer_blockid = 0;
sched_pipeline_params.producer_arv_count = 1;
sched_pipeline_params.consumer_arv_count =
NumSchedThreads +
cluster_size * (NumMainloopLoadThreads + NumEpilogueThreads + NumMmaThreadCount);
sched_pipeline_params.transaction_bytes = sizeof(uint32_t);
sched_pipeline_params.initializing_warp = 3;
SchedPipeline sched_pipeline(shared_storage.sched, sched_pipeline_params, cluster_shape);
SchedPipelineState sched_pipeline_consumer_state;
SchedPipelineState sched_pipeline_producer_state =
cutlass::make_producer_start_state<SchedPipeline>();
typename SchedThrottlePipeline::Params sched_throttle_pipeline_params;
if (is_dma_warp) {
sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Producer;
}
if (is_sched_warp) {
sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Consumer;
}
sched_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads;
sched_throttle_pipeline_params.consumer_arv_count = NumSchedThreads;
sched_throttle_pipeline_params.dst_blockid = 0;
sched_throttle_pipeline_params.initializing_warp = 4;
SchedThrottlePipeline sched_throttle_pipeline(shared_storage.sched_throttle,
sched_throttle_pipeline_params);
SchedThrottlePipelineState sched_pipeline_throttle_consumer_state;
SchedThrottlePipelineState sched_pipeline_throttle_producer_state =
cutlass::make_producer_start_state<SchedThrottlePipeline>();
if (warp_idx == 2 && elect_one_sync()) {
cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1);
}
__syncthreads();
// Warp group roles: DMA (global->shared copy), MMA (tensor core gemm), scheduler, column quantizer, row quantizer
if (is_dma_warp) {
// Warp responsible for loading input from global to shared memory using TMA (Tensor Memory Accelerator).
cutlass::arch::warpgroup_reg_dealloc<32>();
// Get TMA tensors for the input matrix A and the Hadamard matrix B from global memory.
Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, packed_N));
Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize));
// Partition tensors for tiling according to the mainloop and cluster tilers.
Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor gB_nk =
local_tile(mB, cluster_tile, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
// Shared memory tensors for pipeline
Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()),
sAlayout); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()),
sBlayout); // (MMA,MMA_N,MMA_K,PIPE)
// Determine warp/tile positioning
int block_rank_in_cluster = cute::block_rank_in_cluster();
ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx
// Partition global to local fragments for A and B
Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k)
Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k)
Layout cta_layout_mnk = make_layout(cluster_shape);
Layout cta_layout_vmnk =
tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{}));
auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster);
auto [tAgA, tAsA] =
tma_partition(tma_load_a, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)),
group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA));
auto [tBgB, tBsB] =
tma_partition(tma_load_b, get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)),
group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB));
uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk);
if constexpr (kEnableRHTColQuant) {
if (elect_one_sync()) {
cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0],
kTmaRhtTensorTransactionBytes);
copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0),
tBsB(_, 0));
}
}
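// The 16x16 Hadamard matrix B is loaded once and reused for every tile;
// only tiles of A stream through the mainloop pipeline below.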
do {
// is_first_wave indicates whether this CTA is still on its statically assigned first tile, before any dynamically fetched work.
bool is_first_wave = scheduler.is_first_wave();
uint32_t skip_wait = is_first_wave;
auto tAgA_mk = tAgA(_, scheduler.tile_m(), _);
int k_tile = 0;
sched_throttle_pipeline.producer_acquire(sched_pipeline_throttle_producer_state);
sched_throttle_pipeline.producer_commit(sched_pipeline_throttle_producer_state);
++sched_pipeline_throttle_producer_state;
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n()) {
int k_tile_idx_n = scheduler.tile_n_base() + k_tile;
++k_tile;
skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount);
mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state);
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType *tma_barrier =
mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state);
int write_stage = mainloop_pipe_producer_state.index();
++mainloop_pipe_producer_state;
if (cute::elect_one_sync()) {
copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n),
tAsA(_, write_stage));
}
}
scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state);
++sched_pipeline_consumer_state;
scheduler.update_work_tile_info();
// scheduler.advance();
} while (scheduler.is_valid());
mainloop_pipeline.producer_tail(mainloop_pipe_producer_state);
} else if (is_mma_warp) {
// This warp executes the main tensor core matrix-multiply-accumulate for the Hadamard transform.
cutlass::arch::warpgroup_reg_dealloc<32>();
if constexpr (kEnableRHTColQuant) {
// Setup shared memory fragments for A and B tiles.
Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()),
sAlayout); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()),
sBlayout); // (MMA,MMA_N,MMA_K,PIPE)
int block_rank_in_cluster = cute::block_rank_in_cluster();
ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx
// Allocate "fragments" -- these are actually umma smem descriptors
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
mma.accumulate_ = UMMA::ScaleOut::Zero;
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns,
&shared_storage.tmem_base_ptr);
__syncwarp();
tmem_allocation_result_barrier.arrive();
uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
bulk_tmem_mma.data() = tmem_base_ptr;
// Wait until the B (Hadamard) tensor copy is complete
cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/);
do {
uint32_t skip_wait = K_TILE_MAX <= 0;
auto barrier_token =
mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state);
++sched_pipeline_consumer_state;
CUTLASS_PRAGMA_NO_UNROLL
for (int k_tile = 0;
k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) {
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
int read_stage = mainloop_pipe_consumer_state.index();
auto tCrA_mk = tCrA(_, _, _, read_stage);
auto tCrB_nk = tCrB(_, _, 0, 0);
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA) / EpilogueUnrollFactor; ++k_block) {
int accumulator_k_block =
accumulator_pipe_producer_state.index() * EpilogueUnrollFactor;
int tCrA_k_block = k_block * EpilogueUnrollFactor;
accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < EpilogueUnrollFactor; i++) {
auto accumulators = bulk_tmem_mma(_, _, _, accumulator_k_block + i);
gemm(mma, tCrA_mk(_, _, tCrA_k_block + i), tCrB_nk, accumulators);
}
accumulator_pipeline.producer_commit(accumulator_pipe_producer_state);
++accumulator_pipe_producer_state;
}
auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;
++mainloop_pipe_consumer_state;
++k_tile;
skip_wait = k_tile >= K_TILE_MAX;
mainloop_pipeline.umma_consumer_release(curr_mainloop_pipe_consumer_state);
barrier_token =
mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
}
scheduler.update_work_tile_info();
} while (scheduler.is_valid());
tmem_allocator.release_allocation_lock();
accumulator_pipeline.producer_tail(accumulator_pipe_producer_state);
tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
} else if (is_sched_warp) {
// Scheduler warp assigns tiles and publishes scheduling progress to the other warp groups
cutlass::arch::warpgroup_reg_dealloc<32>();
do {
sched_throttle_pipeline.consumer_wait(sched_pipeline_throttle_consumer_state);
sched_throttle_pipeline.consumer_release(sched_pipeline_throttle_consumer_state);
++sched_pipeline_throttle_consumer_state;
sched_pipeline_producer_state =
scheduler.advance_to_next_work(sched_pipeline, sched_pipeline_producer_state);
scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state);
++sched_pipeline_consumer_state;
scheduler.update_work_tile_info();
} while (scheduler.is_valid());
} else if (is_epilogue_col_quant_warp) {
// Warp group (4 warps) responsible for quantizing the Hadamard transform output to FP4 for column-wise usage,
// and for writing the result tensors/scales to global memory.
cutlass::arch::warpgroup_reg_alloc<192>();
if constexpr (kEnableRHTColQuant) {
using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x;
auto acc_epilogue_pipelined_shape =
append(acc_shape_epilogue, Int<AccumulatorPipelineStageCount / EpilogueUnrollFactor>{});
auto bulk_tmem_epilogue_layout = make_layout(
acc_epilogue_pipelined_shape,
make_stride(stride<0>(bulk_tmem_mma), Int<0>{}, Int<0>{}, size<1>(epilogue_tiler)));
auto bulk_tmem_epilogue = make_tensor(make_tmem_ptr<uint32_t>(), bulk_tmem_epilogue_layout);
// Use 256-bit fragments for aligned bulk stores
static int constexpr FragmentSize = 256 / sizeof_bits_v<TD>;
// Wait for TMEM allocation for this pipeline to finish
tmem_allocation_result_barrier.arrive_and_wait();
uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
bulk_tmem_epilogue.data() = tmem_base_ptr;
int global_thread_idx = threadIdx.x;
int local_thread_idx = global_thread_idx % cutlass::NumThreadsPerWarpGroup;
// g2s load all global_d_amax
CUTLASS_PRAGMA_NO_UNROLL
for (int g = local_thread_idx; g < args.num_tensors; g += NumEpilogueColQuantThreadCount) {
shared_storage.global_d_amax[g] =
__ldg(reinterpret_cast<float *>(args.global_d_amax_list[g]));
}
size_t rng_seed = 0;
size_t rng_offset = 0;
// Setup RNG for stochastic rounding
if constexpr (kEnableStochasticRounding) {
rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0;
rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0;
}
int group_idx = GetGroupIdx(&args, scheduler.tile_n_base() * size<1>(epilogue_tiler));
// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = args.split_sections[group_idx];
if constexpr (kEnableSwizzleSFOutput) {
sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
Tensor mD = make_tensor(
cute::subbyte_iterator<TD>(reinterpret_cast<TD *>(args.output_colwise_list[group_idx])),
make_shape(M, cur_N), DStride{}); // (M,packed_N)
Tensor gD_mn =
local_tile(mD, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N)
Tensor mSFD = make_tensor(make_gmem_ptr<TSFD>(reinterpret_cast<TSFD *>(
args.output_colwise_scale_inv_list[group_idx])),
sfd_layout);
Tensor gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _),
Step<_1, _1, X>{}); // (BLK_M,BLK_N)
Tensor gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler));
// Setup tile-level TMEM (t2r) and global memory (r2g) copy descriptors
auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{}));
auto tiled_r2g =
make_tiled_copy_D(Copy_Atom<SM100_STORE_256bit_CACHE_NOALLOCATION, TD>{}, tiled_t2r);
auto thr_t2r = tiled_t2r.get_slice(local_thread_idx);
auto thr_r2g = tiled_r2g.get_slice(local_thread_idx);
cutlass::arch::NamedBarrier::sync(NumEpilogueColQuantThreadCount,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
// Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release}
static constexpr float fp4_max = 6.0f;
static constexpr float fp8_max = 448.0f;
static constexpr float fp4_max_inv = 1.0f / fp4_max;
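// Two-level NVFP4 scaling used below:
//   global_encode_scale = (fp8_max * fp4_max) / tensor_amax      (1.0 if amax == 0)
//   per-vector scale factor (UE4M3) = (vec_amax / fp4_max) * global_encode_scale
//   per-element quantize scale      = 1 / (decoded scale factor * global_decode_scale)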
float c_global_amax_val = shared_storage.global_d_amax[group_idx];
float global_encode_scale = c_global_amax_val > 0.0f
? cutlass::minimum_with_nan_propagation<float>{}(
(fp8_max * fp4_max) / c_global_amax_val,
cutlass::platform::numeric_limits<float>::max())
: 1.0f;
float global_decode_scale = 1.0f / global_encode_scale;
// Scaling factor for fast math path
float global_encode_scale_multiplier = 1.0f;
if constexpr (kUseFastMath) {
global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
}
do {
scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state);
++sched_pipeline_consumer_state;
CUTLASS_PRAGMA_NO_UNROLL
for (int k_tile = 0;
k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();
++k_tile) {
int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler);
int cur_group_idx = GetGroupIdx(&args, global_tile_n_offset);
if (cur_group_idx != group_idx) {
group_idx = cur_group_idx;
c_global_amax_val = shared_storage.global_d_amax[group_idx];
// update amax
global_encode_scale = c_global_amax_val > 0.0f
? cutlass::minimum_with_nan_propagation<float>{}(
(fp8_max * fp4_max) / c_global_amax_val,
cutlass::platform::numeric_limits<float>::max())
: 1.0f;
global_decode_scale = 1.0f / global_encode_scale;
if constexpr (kUseFastMath) {
global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
}
cur_N = args.split_sections[group_idx];
if constexpr (kEnableSwizzleSFOutput) {
sfd_layout =
tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
sfd_layout =
make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// update tensor
mD = make_tensor(cute::subbyte_iterator<TD>(
reinterpret_cast<TD *>(args.output_colwise_list[group_idx])),
make_shape(M, cur_N), DStride{});
gD_mn = local_tile(mD, epilogue_tiler, make_coord(_, _, _),
Step<_1, _1, X>{}); // (BLK_M,BLK_N)
mSFD = make_tensor(make_gmem_ptr<TSFD>(reinterpret_cast<TSFD *>(
args.output_colwise_scale_inv_list[group_idx])),
sfd_layout);
gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _),
Step<_1, _1, X>{}); // (BLK_M,BLK_N)
gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler));
}
int group_start_offset = args.split_sections_range[group_idx];
int local_tile_n_idx =
(global_tile_n_offset - group_start_offset) / size<1>(epilogue_tiler);
Tensor tDgD_mn = gD_mn_view(_, _, _, scheduler.tile_m(), local_tile_n_idx);
Tensor tDgSFD_mn = gSFD_mn(_, _, scheduler.tile_m(), local_tile_n_idx);
accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state);
auto Acc = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index());
Tensor tDtAcc = thr_t2r.partition_S(Acc); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tDgD = thr_t2r.partition_D(tDgD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tTR_rAcc =
make_tensor<ElementAccumulator>(shape(tDgD)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tDrD = make_tensor<TD>(shape(tDgD));
Tensor tTR_rAcc_frag =
recast<cutlass::Array<ElementAccumulator, FragmentSize>>(coalesce(tTR_rAcc));
Tensor tDrD_frag = recast<cutlass::Array<TD, FragmentSize>>(coalesce(tDrD));
Tensor src = thr_r2g.retile_S(tDrD);
Tensor dst = thr_r2g.retile_D(tDgD);
Tensor tDgSFD_view = make_tensor(
tDgSFD_mn.data(), make_layout(make_shape(shape(tDgSFD_mn), Int<1>{}, Int<1>{}),
make_stride(stride(tDgSFD_mn), Int<0>{}, Int<0>{})));
Tensor tDgSFD = filter(thr_t2r.partition_D(tDgSFD_view));
Tensor tDrSFD = make_tensor<TSFD>(shape(tDgSFD));
static int constexpr NumVecs = size(tDgD) / VectorSize;
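// Each thread quantizes NumVecs vectors of VectorSize elements and emits one scale factor per vector.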
Tensor tD_rRowSFD_frg = recast<cutlass::Array<TSFD, NumVecs>>(tDrSFD);
// Compute amax and quantization scales for this tile
cutlass::maximum_absolute_value_reduction<cutlass::Array<ElementAccumulator, VectorSize>,
true>
amax_reduction;
cutlass::Array<ElementAccumulator, NumVecs> vec_maxs;
cutlass::Array<ElementAccumulator, NumVecs> pvscales;
// Copy from TMEM to registers
copy(tiled_t2r, tDtAcc, tTR_rAcc);
cutlass::arch::fence_view_async_tmem_load();
accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state);
++accumulator_pipe_consumer_state;
if constexpr (!kUseFastMath) {
// Downcast to BF16 for bit-wise compatibility with
// unfused kernels
auto convert_accum_to_bf16 =
cutlass::NumericArrayConverter<cutlass::bfloat16_t, ElementAccumulator,
FragmentSize>{};
auto convert_bf16_to_accum =
cutlass::NumericArrayConverter<ElementAccumulator, cutlass::bfloat16_t,
FragmentSize>{};
tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{})));
tTR_rAcc_frag(_1{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_1{})));
}
auto compute_frgs = reinterpret_cast<cutlass::Array<ElementAccumulator, VectorSize> *>(
tTR_rAcc_frag.data());
auto output_frgs = reinterpret_cast<cutlass::Array<TD, VectorSize> *>(tDrD_frag.data());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < NumVecs; v++) {
vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]);
}
if constexpr (kUseFastMath) {
// Fast math: multiply with precomputed reciprocal
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(
vec_maxs, global_encode_scale_multiplier);
} else {
// Accurate math: perform division
pvscales =
cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(vec_maxs, fp4_max);
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(
pvscales, global_encode_scale);
}
auto pvscales_cvted =
cutlass::NumericArrayConverter<TSFD, ElementAccumulator, NumVecs>{}(pvscales);
tD_rRowSFD_frg(_0{}) = pvscales_cvted;
auto qpvscale_ups = cutlass::NumericArrayConverter<ElementAccumulator, TSFD, NumVecs>{}(
tD_rRowSFD_frg(_0{}));
auto qpvscale_scaled = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(
qpvscale_ups, global_decode_scale);
cutlass::Array<ElementAccumulator, NumVecs> acc_scales;
if constexpr (kUseFastMath) {
// Fast math: compute approximate reciprocal
acc_scales =
cutlass::reciprocal_approximate_ftz<decltype(qpvscale_scaled)>{}(qpvscale_scaled);
} else {
// Accurate math: compute reciprocal with division
acc_scales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(
1.0, qpvscale_scaled);
}
// Prepare stochastic rounding random state if enabled
uint4 random_uint4 = uint4{0, 0, 0, 0};
transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
// "Prefetch" a stochastic rounding state for the first tile
if constexpr (kEnableStochasticRounding) {
const size_t rng_sequence = global_thread_idx + k_tile * 512 +
scheduler.get_linear_tile_idx() * K_TILE_MAX * 512;
rng.init(rng_seed, rng_sequence, rng_offset);
}
CUTLASS_PRAGMA_UNROLL
// Apply round/quantize to each fragment, with or without stochastic rounding
for (int v = 0; v < NumVecs; v++) {
auto acc_scale = cutlass::minimum_with_nan_propagation<ElementAccumulator>{}(
acc_scales[v], cutlass::platform::numeric_limits<ElementAccumulator>::max());
if constexpr (kEnableStochasticRounding) {
random_uint4 = rng.generate4();
output_frgs[v] = StochasticNumericConverter(
cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
compute_frgs[v], acc_scale),
*reinterpret_cast<cutlass::Array<uint32_t, 4> *>(&random_uint4));
} else {
output_frgs[v] = cutlass::NumericArrayConverter<TD, ElementAccumulator, VectorSize>{}(
cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
compute_frgs[v], acc_scale));
}
}
// Write quantized FP4 tile and dequant scale to gmem
copy(tiled_r2g, src, dst);
copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFD, tDgSFD);
}
scheduler.update_work_tile_info();
} while (scheduler.is_valid());
}
} else if (is_epilogue_row_quant_warp) {
// Warp group (8 warps) responsible for quantizing the input (before the Hadamard transform) to FP4 for row-wise usage.
cutlass::arch::warpgroup_reg_alloc<136>();
if constexpr (kEnableRowQuant) {
using S2RVectorType = uint128_t;
int global_thread_idx = threadIdx.x;
int local_thread_idx = global_thread_idx % 256;
size_t rng_seed = 0;
size_t rng_offset = 0;
// g2s load all global_a_amax for all groups/tensors
CUTLASS_PRAGMA_NO_UNROLL
for (int g = local_thread_idx; g < args.num_tensors; g += NumEpilogueRowQuantThreadCount) {
shared_storage.global_a_amax[g] =
__ldg(reinterpret_cast<float *>(args.global_a_amax_list[g]));
}
// RNG for stochastic rounding
if constexpr (kEnableStochasticRounding) {
rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0;
rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0;
}
// Input/output tensors/partitions for row quant warp
Tensor mQA =
make_tensor(cute::subbyte_iterator<TQA>(QA), make_layout(make_shape(M, packed_N), dQA));
Tensor gQA_mn = local_tile(mQA, epilogue_tiler, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor mSFA = make_tensor(make_gmem_ptr(SFA), sfa_layout);
Tensor gSFA_mn = local_tile(mSFA, epilogue_tiler, make_coord(_, _, _),
Step<_1, X, _1>{}); // (BLK_M,BLK_N)
// Shared-memory A tile viewed through a position-independent swizzled layout
Tensor sA = as_position_independent_swizzle_tensor(group_modes<0, 2>(
coalesce(make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()),
sAlayout)))); // (BLOCK_M, BLOCK_M,PIPE)
// Set up layouts for partitioning: tile-by-warp, with vector granularity
using S2RWarpLayout = Layout<Shape<_8, _4>>;
using WarpGroupLayout = Layout<Shape<_1, _8>>;
using S2RThreadLayout = decltype(blocked_product(S2RWarpLayout{}, WarpGroupLayout{}));
using S2RValLayout = Layout<Shape<Int<VectorSize>, _1>>;
using S2RAtomA = Copy_Atom<AutoVectorizingCopy, TA>;
using R2GAtomQA = Copy_Atom<AutoVectorizingCopy, TQA>;
using R2GAtomSFA = Copy_Atom<AutoVectorizingCopy, TSFA>;
auto tiled_s2r = make_tiled_copy(S2RAtomA{}, S2RThreadLayout{}, S2RValLayout{});
auto tiled_r2g_QA = make_tiled_copy(R2GAtomQA{}, S2RThreadLayout{}, S2RValLayout{});
auto tiled_r2g_SFA = make_tiled_copy(R2GAtomSFA{}, S2RThreadLayout{}, S2RValLayout{});
auto thr_s2r = tiled_s2r.get_slice(local_thread_idx);
auto thr_r2g_QA = tiled_r2g_QA.get_slice(local_thread_idx);
auto thr_r2g_SFA = tiled_r2g_SFA.get_slice(local_thread_idx);
Tensor tQAsA = thr_s2r.partition_S(sA); // (Copy, Copy_M, Copy_N, PIPE)
// Allocate temporary register tensors for the quantized output and its scale factors
Tensor tQArA = make_tensor_like<TA>(
make_layout(tQAsA(_, _, _, _0{}).shape())); // (Copy, Copy_M, Copy_N)
Tensor tQAgQA = thr_r2g_QA.partition_S(gQA_mn);
Tensor tQArQA = make_tensor_like(tQAgQA(_, _, _, _0{}, _0{}));
Tensor tQAgSFA = thr_r2g_SFA.partition_S(gSFA_mn);
Tensor tQArSFA = make_tensor_like(tQAgSFA(_, _, _, _0{}, _0{}));
int row_quant_barrier_id = 10;
cutlass::arch::NamedBarrier::sync(NumEpilogueRowQuantThreadCount, row_quant_barrier_id);
int group_idx = GetGroupIdx(&args, scheduler.tile_n_base() * size<1>(epilogue_tiler));
float a_global_amax_val = shared_storage.global_a_amax[group_idx];
// Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release}
static constexpr float fp4_max = 6.0f;
static constexpr float fp8_max = 448.0f;
static constexpr float fp4_max_inv = 1.0f / fp4_max;
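// Same two-level scaling recipe as the column-quant epilogue, applied per VectorSize-element row vector.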
float global_encode_scale = a_global_amax_val > 0.0f
? cutlass::minimum_with_nan_propagation<float>{}(
(fp8_max * fp4_max) / a_global_amax_val,
cutlass::platform::numeric_limits<float>::max())
: 1.0f;
float global_decode_scale = 1.0f / global_encode_scale;
float global_encode_scale_multiplier = 1.0f;
if constexpr (kUseFastMath) {
global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
}
auto sfa_converter = cutlass::NumericConverter<TSFA, ElementAccumulator>{};
do {
CUTLASS_PRAGMA_NO_UNROLL
for (int k_tile = 0;
k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) {
int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler);
int cur_group_idx = GetGroupIdx(&args, global_tile_n_offset);
if (cur_group_idx != group_idx) {
group_idx = cur_group_idx;
a_global_amax_val = shared_storage.global_a_amax[group_idx];
// Update group quantization parameters/scaling
global_encode_scale = a_global_amax_val > 0.0f
? cutlass::minimum_with_nan_propagation<float>{}(
(fp8_max * fp4_max) / a_global_amax_val,
cutlass::platform::numeric_limits<float>::max())
: 1.0f;
global_decode_scale = 1.0f / global_encode_scale;
if constexpr (kUseFastMath) {
global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
}
}
auto tQAgSFA_mn = tQAgSFA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile);
auto tQAgQA_mn = tQAgQA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile);
auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state);
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
copy(tiled_s2r, tQAsA(_, _, _, mainloop_pipe_consumer_state.index()), tQArA);
cutlass::arch::fence_view_async_shared();
mainloop_pipeline.consumer_release(mainloop_pipe_consumer_state);
++mainloop_pipe_consumer_state;
++k_tile;
// static int constexpr NumVecs = size(tQArA) / VectorSize;
cutlass::maximum_absolute_value_reduction<cutlass::Array<ElementAccumulator, VectorSize>,
true>
amax_reduction;
auto compute_frgs = reinterpret_cast<cutlass::Array<TA, VectorSize> *>(tQArA.data());
auto output_frgs =
reinterpret_cast<cutlass::Array<TQA, VectorSize> *>(raw_pointer_cast(tQArQA.data()));
Tensor amax =
make_tensor<ElementAccumulator>(prepend(take<1, rank(tQArA)>(tQArA.shape()), _1{}));
Tensor pvscales = make_tensor_like<ElementAccumulator>(amax);
transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
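// The extra tiles_in_m * tiles_in_n * K_TILE_MAX * 512 offset below keeps the row-wise
// RNG sequences disjoint from the ones used by the column-quant epilogue.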
if constexpr (kEnableStochasticRounding) {
const size_t rng_sequence = global_thread_idx + k_tile * 512 +
scheduler.get_linear_tile_idx() * K_TILE_MAX * 512 +
tiles_in_m * tiles_in_n * K_TILE_MAX * 512;
rng.init(rng_seed, rng_sequence, rng_offset);
}
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < size<1>(group_modes<1, rank(tQArA)>(tQArA)); v++) {
auto amax_view = group_modes<1, rank(amax)>(amax);
auto pvscales_view = group_modes<1, rank(pvscales)>(pvscales);
auto compute_frgs_up =
cutlass::NumericArrayConverter<ElementAccumulator, TA, VectorSize>{}(
compute_frgs[v]);
amax_view(_0{}, v) = amax_reduction(ElementAccumulator(0), compute_frgs_up);
if constexpr (kUseFastMath) {
// Fast math: multiply with precomputed reciprocal
pvscales_view(_0{}, v) = cutlass::multiplies<ElementAccumulator>{}(
amax_view(_0{}, v), global_encode_scale_multiplier);
} else {
// Accurate math: perform division
pvscales_view(_0{}, v) =
cutlass::divides<ElementAccumulator>{}(amax_view(_0{}, v), fp4_max);
pvscales_view(_0{}, v) = cutlass::multiplies<ElementAccumulator>{}(
pvscales_view(_0{}, v), global_encode_scale);
}
filter(tQArSFA)(v) = sfa_converter(pvscales_view(_0{}, v));
auto qpvscale_ups =
cutlass::NumericConverter<ElementAccumulator, TSFA>{}(filter(tQArSFA)(v));
auto qpvscale_scaled =
cutlass::multiplies<ElementAccumulator>{}(qpvscale_ups, global_decode_scale);
ElementAccumulator acc_scales;
if constexpr (kUseFastMath) {
// Fast math: compute approximate reciprocal
acc_scales =
cutlass::reciprocal_approximate_ftz<decltype(qpvscale_scaled)>{}(qpvscale_scaled);
} else {
// Accurate math: compute reciprocal with division
acc_scales = cutlass::divides<ElementAccumulator>{}(1.0, qpvscale_scaled);
}
auto acc_scale = cutlass::minimum_with_nan_propagation<ElementAccumulator>{}(
acc_scales, cutlass::platform::numeric_limits<ElementAccumulator>::max());
uint4 random_uint4 = uint4{0, 0, 0, 0};
if constexpr (kEnableStochasticRounding) {
random_uint4 = rng.generate4();
output_frgs[v] = StochasticNumericConverter(
cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
compute_frgs_up, acc_scale),
*reinterpret_cast<cutlass::Array<uint32_t, 4> *>(&random_uint4));
} else {
output_frgs[v] =
cutlass::NumericArrayConverter<TQA, ElementAccumulator, VectorSize>{}(
cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
compute_frgs_up, acc_scale));
}
}
copy(tiled_r2g_QA, tQArQA, tQAgQA_mn);
copy(tiled_r2g_SFA, filter(tQArSFA), filter(tQAgSFA_mn));
}
// scheduler.advance();
scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state);
++sched_pipeline_consumer_state;
scheduler.update_work_tile_info();
} while (scheduler.is_valid());
}
} else {
cutlass::arch::warpgroup_reg_dealloc<32>();
}
} // NOLINT(readability/fn_size)
template <bool kEnableStochasticRounding, bool kEnableRHTColQuant, bool kEnableRowQuant,
bool kEnableSwizzleSFOutput, class TA, class TB, class TQA, class TSFA, class TD = TQA,
class TSFD = TSFA, bool kUseFastMath = false>
void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_size, TA const *A,
TB const *B, TQA *QA, TSFA *SFA,
MultiAmaxHadamardCastFusionArgs &args,
const size_t *rng_state, uint32_t sm_count,
cudaStream_t stream, int k_tile_size = 1024) {
using namespace cute;
static int constexpr SFVecSize = 16;
static int constexpr RhtTensorSize = 16;
static_assert(RhtTensorSize == 16, "RhtTensorSize must be 16");
using LinearSFALayout = decltype(make_layout(make_shape(make_shape(Int<SFVecSize>{}, 0), 0),
make_stride(make_stride(_0{}, _1{}), 0)));
using LinearSFDLayout = decltype(make_layout(make_shape(0, make_shape(Int<SFVecSize>{}, 0)),
make_stride(0, make_stride(_0{}, _1{}))));
using SwizzledSFALayoutAtom =
cutlass::detail::Sm1xxBlockScaledOutputConfig<SFVecSize, UMMA::Major::MN>::SfAtom;
using SwizzledSFDLayoutAtom =
cutlass::detail::Sm1xxBlockScaledOutputConfig<SFVecSize, UMMA::Major::K>::SfAtom;
using SwizzledSFALayout = decltype(tile_to_shape(
SwizzledSFALayoutAtom{}, make_shape(hidden_size, packed_sequence_length), Step<_1, _2>{}));
using SwizzledSFDLayout = decltype(tile_to_shape(
SwizzledSFDLayoutAtom{}, make_shape(hidden_size, packed_sequence_length), Step<_2, _1>{}));
using SFALayout = cute::conditional_t<kEnableSwizzleSFOutput, SwizzledSFALayout, LinearSFALayout>;
using SFDLayout = cute::conditional_t<kEnableSwizzleSFOutput, SwizzledSFDLayout, LinearSFDLayout>;
SFALayout sfa_layout;
SFDLayout sfd_layout;
if constexpr (kEnableSwizzleSFOutput) {
sfa_layout = tile_to_shape(SwizzledSFALayoutAtom{},
make_shape(hidden_size, packed_sequence_length), Step<_1, _2>{});
sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{},
make_shape(hidden_size, packed_sequence_length), Step<_2, _1>{});
} else {
sfa_layout = make_layout(
make_shape(make_shape(Int<SFVecSize>{}, hidden_size / SFVecSize), packed_sequence_length),
make_stride(make_stride(_0{}, _1{}), hidden_size / SFVecSize));
sfd_layout = make_layout(
make_shape(hidden_size, make_shape(Int<SFVecSize>{}, packed_sequence_length / SFVecSize)),
make_stride(packed_sequence_length / SFVecSize, make_stride(_0{}, _1{})));
}
// Define shapes (dynamic)
auto M = hidden_size;
auto N = packed_sequence_length;
Tensor tensorA = make_tensor(A, make_shape(hidden_size, packed_sequence_length), LayoutLeft{});
Tensor tensorB = make_tensor(B, make_shape(RhtTensorSize, RhtTensorSize), LayoutLeft{});
Tensor tensorQA = make_tensor(QA, make_shape(hidden_size, packed_sequence_length), LayoutLeft{});
Tensor tensorSFA = make_tensor(SFA, sfa_layout);
// Define strides (from tensors)
auto dA = stride(tensorA); // (dM,dK)
auto dB = stride(tensorB); // (dN,dK)
auto dD = LayoutRight{}; // (dM,dN)
auto dQA = stride(tensorQA); // (dM,dK)
using ClusterShape = Shape<_1, _1, _1>;
auto cluster_shape = ClusterShape{};
auto cluster_tile_shape = Shape<_128, Int<RhtTensorSize>, Int<RhtTensorSize>>{};
auto cluster_tile_mainloop = Shape<_128, Int<RhtTensorSize>, _128>{};
// Each mainloop stage loads a 128 x 128 tile of A, while each MMA step consumes a 128 x 16 slice
static int constexpr EpilogueUnrollFactor =
size<2>(cluster_tile_mainloop) / size<2>(cluster_tile_shape);
// Construct the MMA
auto mma = make_tiled_mma(
SM100_MMA_F16BF16_SS<TA, TB, float, size<0>(cluster_tile_shape), size<1>(cluster_tile_shape),
UMMA::Major::MN, UMMA::Major::MN>{},
Layout<Shape<_1, _1>>{});
// Assert that the TiledMMA uses all CTAs in the CGA.
CUTE_STATIC_ASSERT_V(size(cluster_shape) == size(mma));
CUTE_STATIC_ASSERT_V(evenly_divides(cluster_tile_shape, tile_shape(mma)));
// Determine the A and B shapes
auto mma_shape_B =
partition_shape_B(mma, make_shape(size<1>(cluster_tile_shape), size<2>(cluster_tile_shape)));
using TiledMma = decltype(mma);
using AtomThrID = typename TiledMma::AtomThrID;
using SmemShape_M = decltype(shape_div(
shape<0>(cluster_tile_shape),
shape_div(shape<0>(cluster_tile_shape), size<0>(cluster_tile_shape) / size(AtomThrID{}))));
using SmemShape_N = decltype(shape_div(
shape<1>(cluster_tile_shape),
shape_div(shape<1>(cluster_tile_shape), size<1>(cluster_tile_shape) / size(AtomThrID{}))));
using SmemShape_K = decltype(cute::get<2>(cluster_tile_shape));
using SmemLayoutAtomB =
decltype(cutlass::gemm::collective::detail::sm100_smem_selector<cute::UMMA::Major::MN, TB,
SmemShape_N, SmemShape_K>());
auto mma_shape_A = partition_shape_A(
mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop)));
using SmemShape_M_A =
decltype(shape_div(shape<0>(cluster_tile_mainloop),
shape_div(shape<0>(cluster_tile_mainloop),
size<0>(cluster_tile_mainloop) / size(AtomThrID{}))));
using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop));
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>());
static uint32_t constexpr TotalTmemRows = 128;
static uint32_t constexpr Sm100TmemCapacityColumns = 512;
static uint32_t constexpr TotalTmem = TotalTmemRows * Sm100TmemCapacityColumns;
static uint32_t constexpr AccumulatorPipelineStageCount =
TotalTmem / (cute::size<0>(cluster_tile_shape) * cute::size<1>(cluster_tile_shape));
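// For the 128x16 MMA tile this gives (128 * 512) / (128 * 16) = 32 accumulator tiles in TMEM,
// i.e. 32 / EpilogueUnrollFactor = 4 accumulator pipeline stages.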
// Define the smem layouts (static)
// Calculate max pipeline stages based on Blackwell SM100's 227 KB of dynamic shared memory per CTA
constexpr int SchedulerPipelineStageCount = 4;
static int constexpr MainloopPipelineBytes = sizeof(
typename cutlass::detail::CustomizedPipelineTmaUmmaAsync<1, Shape<_1, _1, _1>,
Shape<_1, _1, _1>>::SharedStorage);
static int constexpr SchedulerWorkspaceBytes = sizeof(int) * SchedulerPipelineStageCount;
static int constexpr SchedulerThrottlePipelineBytes =
sizeof(typename cutlass::PipelineAsync<SchedulerPipelineStageCount>::SharedStorage);
static int constexpr SchedulerPipelineBytes =
sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount,
ClusterShape>::SharedStorage);
static int constexpr TmemDeallocBytes = sizeof(cutlass::arch::ClusterBarrier);
static int constexpr BTensorBytes = cute::size(mma_shape_B) * sizeof(TB);
static int constexpr AccPipelineBytes = sizeof(
typename cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / EpilogueUnrollFactor,
Shape<_1, _1, _1>>::SharedStorage);
static int constexpr TmemBasePtrsBytes = sizeof(uint32_t);
static int constexpr kBlackwellSmemSize = 232448;  // 227 KB (232448 bytes)
static int constexpr kBytesPerStage =
cute::size(mma_shape_A) * sizeof(TA) + MainloopPipelineBytes;
static int constexpr kReservedBytes = SchedulerWorkspaceBytes + SchedulerThrottlePipelineBytes +
SchedulerPipelineBytes + TmemBasePtrsBytes +
TmemDeallocBytes + BTensorBytes +
AccPipelineBytes; // Reserve for barriers and other uses
static int constexpr kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage;
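// With a BF16 128 x 128 A tile each stage needs roughly 32 KB, so kMaxStages comes out to
// about (232448 - kReservedBytes) / 32 KB, on the order of 7 mainloop stages.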
auto sP = Int<kMaxStages>{}; // SMEM pipelines
auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP),
Step<_2, _1, _3>{}); // (MMA,MMA_M,MMA_K,PIPE)
auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{},
append(mma_shape_B, _1{})); // (MMA,MMA_N,MMA_K, _1)
auto sD = Layout<_1>{}; // XXX Dummy
auto tma_load_a =
make_tma_copy_A_sm100(SM90_TMA_LOAD{}, tensorA, sA(_, _, _, 0), cluster_tile_mainloop, mma);
auto tma_load_b =
make_tma_copy_B_sm100(SM90_TMA_LOAD{}, tensorB, sB(_, _, _, 0), cluster_tile_shape, mma);
// Assert checks on tile sizes -- no predication
assert(M % size<0>(cluster_tile_shape) == 0);
assert(N % size<1>(cluster_tile_shape) == 0);
dim3 dimBlock(512);
dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape));
dim3 dimGrid(sm_count, 1, 1);
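// Persistent kernel: launch one 512-thread CTA per SM and let the atomic tile scheduler
// distribute work dynamically across them.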
int smem_size = sizeof(
SharedStorage<TA, TB, decltype(sA), decltype(sB), ClusterShape, AccumulatorPipelineStageCount,
EpilogueUnrollFactor, SchedulerPipelineStageCount>);
auto *kernel_ptr = &group_row_col_rht_gemm_device<
decltype(M), decltype(N), decltype(k_tile_size), decltype(cluster_shape),
decltype(cluster_tile_shape), TA, decltype(dA), decltype(sA), decltype(tma_load_a), TB,
decltype(dB), decltype(sB), decltype(tma_load_b), TD, decltype(dD), decltype(sD), TSFD,
decltype(sfd_layout), TQA, decltype(dQA), TSFA, decltype(sfa_layout), decltype(mma),
AccumulatorPipelineStageCount, SchedulerPipelineStageCount, kEnableStochasticRounding,
kEnableRHTColQuant, kEnableRowQuant, kEnableSwizzleSFOutput, kUseFastMath>;
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Allocate workspace and set to zero
void *tile_scheduler_workspace = nullptr;
NVTE_CHECK_CUDA(cudaMallocAsync(&tile_scheduler_workspace, sizeof(uint32_t), stream));
NVTE_CHECK_CUDA(cudaMemsetAsync(tile_scheduler_workspace, 0, sizeof(uint32_t), stream));
// Launch kernel
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream};
cutlass::Status status = cutlass::launch_kernel_on_cluster(
params, (void const *)kernel_ptr, M, N, k_tile_size, cluster_shape, cluster_tile_shape, A, dA,
sA, tma_load_a, B, dB, sB, tma_load_b, QA, dQA, SFA, sfa_layout, args,
tile_scheduler_workspace, mma, rng_state);
NVTE_CHECK_CUDA(cudaGetLastError());
NVTE_CHECK(status == cutlass::Status::kSuccess, "Kernel launch failed.");
NVTE_CHECK_CUDA(cudaFreeAsync(tile_scheduler_workspace, stream));
}
} // namespace
} // namespace detail
void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tensor *> &output_list,
const size_t *split_sections, size_t num_tensors,
const Tensor &hadamard_matrix_,
QuantizationConfig &quant_config, cudaStream_t stream) {
NVTE_API_CALL(group_hadamard_transform_cast_fusion);
using transformer_engine::detail::kMaxTensorsPerKernel;
using transformer_engine::detail::MultiAmaxHadamardCastFusionArgs;
NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
"Input must be a BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
NVTE_CHECK(input_.dim() >= 2, "Input must have at least 2 dimensions.");
const SimpleTensor &input = input_.data;
NVTE_CHECK(output_list.size() == num_tensors,
"Number of output tensors should match num_tensors.");
NVTE_CHECK(num_tensors <= kMaxTensorsPerKernel,
"Number of tensors should be less than or equal to ", kMaxTensorsPerKernel);
// construct the multi-tensor args
MultiAmaxHadamardCastFusionArgs kernel_args;
kernel_args.num_tensors = 0;
kernel_args.split_sections_range[0] = 0;
bool all_has_row_quant = true;
bool all_has_col_quant = true;
void *rowwise_data_base_ptr = nullptr;
void *rowwise_scale_inv_base_ptr = nullptr;
for (size_t i = 0; i < num_tensors; ++i) {
NVTE_CHECK(split_sections[i] % 128 == 0, "Element ", i,
" of split_sections must be a multiple of 128.");
if (split_sections[i] == 0) {
continue;
}
bool has_row_quant = output_list[i]->data.dptr != nullptr;
bool has_col_quant = output_list[i]->columnwise_data.dptr != nullptr;
all_has_row_quant = all_has_row_quant && has_row_quant;
all_has_col_quant = all_has_col_quant && has_col_quant;
// sanity check, the two bool flags cannot be both false
NVTE_CHECK(has_row_quant || has_col_quant,
"At least one of the output tensors must have row or column quant.");
void *amax_rowwise_ptr =
has_row_quant ? reinterpret_cast<void *>(output_list[i]->amax.dptr) : nullptr;
void *amax_colwise_ptr =
has_col_quant ? reinterpret_cast<void *>(output_list[i]->columnwise_amax.dptr) : nullptr;
void *rowwise_data_ptr =
has_row_quant ? reinterpret_cast<void *>(output_list[i]->data.dptr) : nullptr;
void *rowwise_scale_inv_ptr =
has_row_quant ? reinterpret_cast<void *>(output_list[i]->scale_inv.dptr) : nullptr;
if (all_has_row_quant &&
(rowwise_data_base_ptr == nullptr || rowwise_scale_inv_base_ptr == nullptr)) {
rowwise_data_base_ptr = rowwise_data_ptr;
rowwise_scale_inv_base_ptr = rowwise_scale_inv_ptr;
}
void *output_colwise_ptr =
has_col_quant ? reinterpret_cast<void *>(output_list[i]->columnwise_data.dptr) : nullptr;
void *output_colwise_scale_inv_ptr =
has_col_quant ? reinterpret_cast<void *>(output_list[i]->columnwise_scale_inv.dptr)
: nullptr;
kernel_args.global_a_amax_list[kernel_args.num_tensors] = amax_rowwise_ptr;
kernel_args.global_d_amax_list[kernel_args.num_tensors] = amax_colwise_ptr;
kernel_args.output_colwise_list[kernel_args.num_tensors] = output_colwise_ptr;
kernel_args.output_colwise_scale_inv_list[kernel_args.num_tensors] =
output_colwise_scale_inv_ptr;
kernel_args.split_sections[kernel_args.num_tensors] = split_sections[i];
kernel_args.split_sections_range[kernel_args.num_tensors + 1] =
kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i];
kernel_args.num_tensors++;
}
// Stochastic rounding config
const bool use_stochastic_rounding = quant_config.stochastic_rounding;
const size_t *rng_state = nullptr;
if (use_stochastic_rounding) {
NVTE_CHECK(quant_config.rng_state != nullptr,
"Enabled stochastic rounding without providing RNG state");
const Tensor &rng_state_tensor = *convertNVTETensorCheck(quant_config.rng_state);
NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64,
"RNG state should contain 2 64-bit values.");
NVTE_CHECK(rng_state_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape);
rng_state = reinterpret_cast<const size_t *>(rng_state_tensor.data.dptr);
}
// Template arguments
using TA = cute::bfloat16_t;
using TB = cute::bfloat16_t;
using TD = cutlass::float_e2m1_t;
using TSFD = cutlass::float_ue4m3_t;
using TQA = TD;
using TSFA = TSFD;
checkCuDriverContext(stream);
// Check Hadamard matrix
constexpr int kHadamardDimension = 16;
NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16,
"Hadamard matrix must be BF16 tensor, but dtype is ",
to_string(hadamard_matrix_.dtype()), ".");
const SimpleTensor &hadamard_matrix = hadamard_matrix_.data;
NVTE_CHECK(
(hadamard_matrix_.shape() == std::vector<size_t>{kHadamardDimension, kHadamardDimension}),
"Hadamard matrix must have shape=",
std::vector<size_t>{kHadamardDimension, kHadamardDimension},
", but got shape=", hadamard_matrix_.shape(), ".");
const size_t hadamard_dimension = hadamard_matrix.shape[0];
const size_t ndim = input.shape.size();
const size_t n = input.shape[ndim - 1];
size_t m = 1;
for (size_t i = 0; i < ndim - 1; ++i) {
m *= input.shape[i];
}
auto sm_count = transformer_engine::cuda::sm_count();
NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension.");
NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension.");
int k_tile_size = 1024;
const bool use_swizzle_sf_output = false;
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, kEnableStochasticRounding,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
all_has_col_quant, kEnableRhtColQuant,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
all_has_row_quant, kEnableRowQuant,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_swizzle_sf_output, kEnableSwizzleSFOutput,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
quant_config.use_fast_math, kUseFastMath,
if constexpr (kEnableRhtColQuant || kEnableRowQuant) {
detail::group_row_col_rht_gemm_ntt_w_sfc<
kEnableStochasticRounding, kEnableRhtColQuant, kEnableRowQuant,
kEnableSwizzleSFOutput, TA, TB, TQA, TSFA, TD, TSFD, kUseFastMath>(
/*packed_sequence_length=*/m, /*hidden_size=*/n,
/*A=*/reinterpret_cast<TA const *>(input.dptr),
/*B=*/reinterpret_cast<TB const *>(hadamard_matrix.dptr),
/*QA=*/reinterpret_cast<TQA *>(rowwise_data_base_ptr),
/*SFA=*/reinterpret_cast<TSFA *>(rowwise_scale_inv_base_ptr),
/*args=*/kernel_args,
/*rng_state=*/rng_state, /*sm_count=*/sm_count,
/*stream=*/stream, /*k_tile_size=*/k_tile_size);
} else {
NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=",
kEnableRhtColQuant, ", kEnableRowQuant=", kEnableRowQuant, ").");
}
);););););
}
} // namespace transformer_engine
void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor *outputs,
const NVTETensor hadamard_matrix,
const size_t *split_sections,
const size_t num_tensors,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_hadamard_transform_cast_fusion);
using namespace transformer_engine;
NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0.");
Tensor *input_tensor = convertNVTETensorCheck(input);
std::vector<Tensor *> output_list(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
output_list[i] = convertNVTETensorCheck(outputs[i]);
}
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Call the multi-tensor Hadamard transform cast fusion implementation.
group_hadamard_transform_cast_fusion(*input_tensor, output_list, split_sections, num_tensors,
*convertNVTETensorCheck(hadamard_matrix), quant_config_cpp,
stream);
}
......@@ -29,7 +29,6 @@
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/helper_cuda.hpp"
#include "cutlass/util/print_error.hpp"
// clang-format off
......@@ -129,7 +128,8 @@ template <class MShape, class NShape, class KShape, class ClusterTileShape,
class TC, class CStride, class CSmemLayout,
class TSFC,
class TiledMMA,
bool kEnableStochasticRounding = false,
bool kUseFastMath = false>
__global__ static
void
rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
......@@ -426,7 +426,13 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val);
const float global_decode_scale = 1.0f / global_encode_scale;
auto sfd_converter = cutlass::NumericConverter<TSFC, float>{};
// Scaling factor for fast math path
float global_encode_scale_multiplier = 1.0f;
if constexpr (kUseFastMath) {
static constexpr float fp4_max_inv = 1.0f / fp4_max;
global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
}
do {
for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) {
......@@ -469,10 +475,13 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
++accumulator_pipe_consumer_state;
if constexpr (!kUseFastMath) {
// Downcast to BF16 for bit-wise compatibility with unfused
// kernels
auto convert_accum_to_bf16 = cutlass::NumericArrayConverter<cutlass::bfloat16_t, ElementAccumulator, FragmentSize>{};
auto convert_bf16_to_accum = cutlass::NumericArrayConverter<ElementAccumulator, cutlass::bfloat16_t, FragmentSize>{};
tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{})));
}
auto compute_frgs = reinterpret_cast<cutlass::Array< ElementAccumulator, VectorSize> *>(tTR_rAcc_frag.data());
auto output_frgs = reinterpret_cast<cutlass::Array< TC, VectorSize> *>(tDrC_frag.data());
......@@ -481,14 +490,27 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]);
}
if constexpr (kUseFastMath) {
// Fast math: multiply with precomputed reciprocal
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(vec_maxs, global_encode_scale_multiplier);
} else {
// Accurate math: perform division
pvscales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(vec_maxs, fp4_max);
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(pvscales, global_encode_scale);
}
auto pvscales_cvted = cutlass::NumericArrayConverter<TSFC, ElementAccumulator, NumVecs>{}(pvscales);
tC_rRowSFD_frg(_0{}) = pvscales_cvted;
auto qpvscale_ups = cutlass::NumericArrayConverter<ElementAccumulator, TSFC, NumVecs>{}(tC_rRowSFD_frg(_0{}));
auto qpvscale_scaled = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(qpvscale_ups, global_decode_scale);
cutlass::Array<ElementAccumulator, NumVecs> acc_scales;
if constexpr (kUseFastMath) {
// Fast math: compute approximate reciprocal
acc_scales = cutlass::reciprocal_approximate_ftz<decltype(qpvscale_scaled)>{}(qpvscale_scaled);
} else {
// Accurate math: compute reciprocal with division
acc_scales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(1.0, qpvscale_scaled);
}
// Initialize RNG for tile
const size_t rng_sequence
......@@ -532,7 +554,7 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
// B: 16 x 16: row-major
// C: m x n: row-major
// SFC: m x (n/16): row-major
template <typename TA, typename TB, typename TC, typename TSFC, bool kEnableStochasticRounding = false, bool kUseFastMath = false>
void
rht_gemm_ntt_w_sfc(int m, int n,
TA const* A,
......@@ -644,16 +666,15 @@ rht_gemm_ntt_w_sfc(int m, int n,
TC, decltype(dC), decltype(sC),
TSFC,
decltype(mma),
kEnableStochasticRounding,
kUseFastMath>;
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(*kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size)
);
(*kernel_ptr)
<<< dimGrid, dimBlock, smem_size, stream >>>
(M, N, k_tile_size, cga_tile_shape,
......@@ -663,11 +684,12 @@ rht_gemm_ntt_w_sfc(int m, int n,
SFC,
mma, global_amax,
rng_state);
NVTE_CHECK_CUDA(cudaGetLastError());
}
// Wrapper around rht_gemm_ntt_w_sfc that treats the input tensor A as transposed
// by swapping the m and n dimensions.
template <typename TA, typename TB, typename TC, typename TSFC, bool kEnableStochasticRounding = false, bool kUseFastMath = false>
void
rht_gemm_ttt_wrapper(int m, int n,
TA const* A,
......@@ -690,7 +712,7 @@ rht_gemm_ttt_wrapper(int m, int n,
// B: 16 x 16: row-major
// C: n x m: row-major
// SFC: n x (m/16): row-major
rht_gemm_ntt_w_sfc<TA, TB, TC, TSFC, kEnableStochasticRounding, kUseFastMath>(
n, m,
A, B, C,
SFC, global_amax,
......@@ -800,9 +822,12 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
} else if (m < 1024 || n < 1024) {
k_tile_size = 512;
}
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, kUseStochasticRounding,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
quant_config.use_fast_math, kUseFastMath,
detail::rht_gemm_ttt_wrapper<TA, TB, TC, TSFC, kUseStochasticRounding, kUseFastMath>(
/*m=*/m,
/*n=*/n,
/*A=*/reinterpret_cast<TA const *>(input.dptr),
......@@ -813,7 +838,7 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
/*rng_state=*/rng_state,
/*sm_count=*/sm_count,
/*stream=*/stream,
/*k_tile_size=*/k_tile_size);););
}
} // namespace transformer_engine
......
......@@ -270,6 +270,20 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
const NVTEQuantizationConfig quant_config, const size_t num_tensors,
cudaStream_t stream);
/*! \brief Casts a grouped input tensor to quantized output tensors.
*
* \param[in] input Input tensor to be cast.
* \param[in,out] outputs Output quantized tensors.
* \param[in] split_sections Split sections of the input tensor.
* \param[in] num_tensors Number of output tensors.
* \param[in] quant_config (Optional) Quantization configurations.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs,
const size_t *split_sections, size_t num_tensors,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -86,6 +86,43 @@ void nvte_group_hadamard_transform_amax(const NVTETensor input, NVTETensor* outp
int random_sign_mask, int random_sign_mask_t,
cudaStream_t stream);
/*!
* \brief Perform the grouped-tensor columnwise Hadamard transform cast fusion operation.
*
 * This function is experimental and the API is not stable. The group_ prefix indicates that the
 * input is a single contiguous tensor formed by concatenating the per-group tensors along dimension 0.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] outputs Array of output tensors.
* \param[in] hadamard_matrix Hadamard matrix to use for transformation.
* \param[in] split_sections Array specifying splits in dimension 0 for each output tensor.
* \param[in] num_tensors Number of output tensors, must be > 0.
* \param[in] quant_config Quantization configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_hadamard_transform_cast_fusion_columnwise(
const NVTETensor input, NVTETensor* outputs, const NVTETensor hadamard_matrix,
const size_t* split_sections, size_t num_tensors, const NVTEQuantizationConfig quant_config,
cudaStream_t stream);
/*!
* \brief Perform the grouped-tensor row quantize (without Hadamard) and columnwise Hadamard transform cast fusion operation.
*
 * This function is experimental and the API is not stable. The group_ prefix indicates that the
 * input is a single contiguous tensor formed by concatenating the per-group tensors along dimension 0.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] outputs Array of output tensors.
* \param[in] hadamard_matrix Hadamard matrix to use for transformation.
* \param[in] split_sections Array specifying splits in dimension 0 for each output tensor.
* \param[in] num_tensors Number of output tensors, must be > 0.
* \param[in] quant_config Quantization configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor* outputs,
const NVTETensor hadamard_matrix,
const size_t* split_sections, size_t num_tensors,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream);
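/* Illustrative dispatch between the two fused variants (a sketch that mirrors how the
 * PyTorch extension uses them; variable names are hypothetical): when every split
 * section is a multiple of 128, the fully fused kernel produces both row-wise and
 * column-wise outputs in one pass; otherwise the column-wise-only fusion is used and
 * the row-wise outputs are produced by a separate quantize call.
 *
 *   bool all_aligned = true;
 *   for (size_t i = 0; i < num_tensors; ++i) {
 *     all_aligned = all_aligned && (split_sections[i] % 128 == 0);
 *   }
 *   if (all_aligned) {
 *     nvte_group_hadamard_transform_cast_fusion(input, outputs, hadamard_matrix,
 *                                               split_sections, num_tensors, quant_config, stream);
 *   } else {
 *     nvte_group_hadamard_transform_cast_fusion_columnwise(input, outputs_t, hadamard_matrix,
 *                                                          split_sections, num_tensors, quant_config, stream);
 *     // Row-wise outputs come from a separate call, e.g. nvte_group_nvfp4_quantize_with_amax.
 *   }
 */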
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -337,6 +337,12 @@ enum NVTEQuantizationConfigAttribute {
kNVTEQuantizationConfigNVFP42DQuantization = 5,
/*! Whether to enable stochastic rounding */
kNVTEQuantizationConfigStochasticRounding = 6,
/*! Whether to enable fast math operations with reduced accuracy.
 *
 * Optimizations are kernel-specific and may be applied
 * inconsistently between kernels.
 */
kNVTEQuantizationConfigUseFastMath = 7,
kNVTEQuantizationConfigNumAttributes
};
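/* Illustrative use of the new attribute (a sketch; `config` is an already-created
 * NVTEQuantizationConfig, and the attribute is bool-sized as registered above):
 *
 *   bool use_fast_math = true;
 *   nvte_set_quantization_config_attribute(config, kNVTEQuantizationConfigUseFastMath,
 *                                          &use_fast_math, sizeof(bool));
 */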
......@@ -997,6 +1003,12 @@ class QuantizationConfigWrapper {
&stochastic_rounding, sizeof(bool));
}
/*! \brief Set whether to enable fast math operations */
void set_use_fast_math(bool use_fast_math) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigUseFastMath,
&use_fast_math, sizeof(bool));
}
private:
/*! \brief Wrapped NVTEQuantizationConfig. */
NVTEQuantizationConfig config_ = nullptr;
......
......@@ -857,9 +857,10 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
// Write attribute size
NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
"Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
if (size_written != nullptr) {
*size_written = attr_size;
}
// Return immediately if buffer is not provided
if (buf == nullptr) {
......@@ -889,6 +890,18 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size);
break;
case kNVTEQuantizationConfigRNGState:
std::memcpy(buf, &config_.rng_state, attr_size);
break;
case kNVTEQuantizationConfigNVFP42DQuantization:
std::memcpy(buf, &config_.nvfp4_2d_quantization, attr_size);
break;
case kNVTEQuantizationConfigStochasticRounding:
std::memcpy(buf, &config_.stochastic_rounding, attr_size);
break;
case kNVTEQuantizationConfigUseFastMath:
std::memcpy(buf, &config_.use_fast_math, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
......@@ -933,6 +946,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigStochasticRounding:
std::memcpy(&config_.stochastic_rounding, buf, attr_size);
break;
case kNVTEQuantizationConfigUseFastMath:
std::memcpy(&config_.use_fast_math, buf, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
......
......@@ -16,6 +16,7 @@
#include "../extensions.h"
#include "common.h"
#include "common/util/system.h"
#include "pybind.h"
#include "transformer_engine/transformer_engine.h"
......@@ -709,124 +710,157 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
return retval;
}
void split_quantize_nvfp4_impl(const TensorWrapper &input,
const std::vector<TensorWrapper> &input_list,
std::vector<TensorWrapper> &output_list,
const std::vector<size_t> &split_sections,
const std::vector<NVFP4Quantizer *> &quantizers) {
// Check tensor lists
const size_t num_tensors = split_sections.size();
NVTE_CHECK(input_list.size() == num_tensors, "Expected ", num_tensors, " input tensors, but got ",
input_list.size(), ".");
NVTE_CHECK(output_list.size() == num_tensors, "Expected ", num_tensors,
" output tensors, but got ", output_list.size(), ".");
NVTE_CHECK(quantizers.size() == num_tensors, "Expected ", num_tensors,
" NVFP4 quantizers, but got ", quantizers.size(), ".");
// Owns all allocations/wrappers backing quant_config_list[*].set_rng_state(...).
struct StochasticRngStateResources {
at::Tensor rng_states_tensor; // [2 * num_tensors], int64, CUDA
at::Tensor rng_states_tensor_colwise; // optional, same shape/dtype/device
std::vector<TensorWrapper> te_rng_state_list;
std::vector<TensorWrapper> te_rng_state_list_colwise;
bool enabled{false};
bool need_separate_rng_states{false};
bool with_bulk_generate_rng_states{false};
};
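// Layout note: each tensor i owns a [seed, offset] int64 pair at rng_states_tensor[2*i]
// and [2*i + 1]; the TensorWrapper entries in te_rng_state_list are shape-{2} views into
// that buffer, so this struct must outlive the quantization calls that consume the
// rng_state pointers.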
// Populates quant_config_list (+ optional colwise list) with rng_state pointers and stochastic flag.
static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper(
size_t num_tensors, bool stochastic_rounding, bool with_bulk_generate_rng_states,
bool need_separate_rng_states,
std::vector<QuantizationConfigWrapper> &quant_config_list_rowwise,
std::vector<QuantizationConfigWrapper> &quant_config_list_colwise) {
// The returned object keeps the RNG states alive; the caller must hold it while the configs are in use
StochasticRngStateResources res;
res.enabled = stochastic_rounding;
res.need_separate_rng_states = need_separate_rng_states;
res.with_bulk_generate_rng_states = with_bulk_generate_rng_states;
if (!stochastic_rounding) return res;
// Basic sanity: caller usually pre-sizes these to num_tensors.
TORCH_CHECK(quant_config_list_rowwise.size() == num_tensors,
"quant_config_list_rowwise must be sized to num_tensors");
if (need_separate_rng_states) {
TORCH_CHECK(quant_config_list_colwise.size() == num_tensors,
"quant_config_list_colwise must be sized to num_tensors when "
"need_separate_rng_states=true");
}
const size_t rng_elts_per_thread =
res.with_bulk_generate_rng_states ? (1024 * num_tensors) : 1024;
// Trivial cases
if (num_tensors == 0) {
return;
}
if (input.numel() == 0) {
for (const auto &tensor : input_list) {
NVTE_CHECK(tensor.numel() == 0,
"Input tensor has zero elements but got split with non-zero elements");
auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
res.rng_states_tensor = torch::empty({static_cast<int64_t>(2 * num_tensors)}, opts);
if (need_separate_rng_states) {
res.rng_states_tensor_colwise = torch::empty({static_cast<int64_t>(2 * num_tensors)}, opts);
}
return;
res.te_rng_state_list.reserve(num_tensors);
if (need_separate_rng_states) res.te_rng_state_list_colwise.reserve(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
// Rowwise RNG state
at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
int64_t *rng_state_ptr = static_cast<int64_t *>(res.rng_states_tensor.data_ptr()) + i * 2;
philox_unpack(philox_args, rng_state_ptr);
res.te_rng_state_list.push_back(makeTransformerEngineTensor(
static_cast<void *>(rng_state_ptr), std::vector<size_t>{2}, DType::kInt64));
quant_config_list_rowwise[i].set_rng_state(res.te_rng_state_list[i].data());
quant_config_list_rowwise[i].set_stochastic_rounding(true);
// Colwise RNG state (only if you truly need a different sequence)
if (need_separate_rng_states) {
// re-initialize philox_args for colwise RNG state
at::PhiloxCudaState philox_args_col = init_philox_state(gen, rng_elts_per_thread);
int64_t *rng_state_ptr_colwise =
static_cast<int64_t *>(res.rng_states_tensor_colwise.data_ptr()) + i * 2;
philox_unpack(philox_args_col, rng_state_ptr_colwise);
res.te_rng_state_list_colwise.push_back(makeTransformerEngineTensor(
static_cast<void *>(rng_state_ptr_colwise), std::vector<size_t>{2}, DType::kInt64));
quant_config_list_colwise[i].set_rng_state(res.te_rng_state_list_colwise[i].data());
quant_config_list_colwise[i].set_stochastic_rounding(true);
}
// Assume all quantizers have identical config
const auto &quantizer = *quantizers.front();
NVTE_CHECK(!quantizer.with_2d_quantization,
"NVFP4 split-quantize does not support 2D quantization");
NVTE_CHECK(!quantizer.with_amax_reduction,
"NVFP4 split-quantize does not support amax reduction");
// When RNG states are bulk-generated, a single state covers all tensors, so stop after the first iteration
if (res.with_bulk_generate_rng_states) break;
}
// Check input tensor shape
const size_t input_last_dim = input.ndim() > 0 ? input.size(input.ndim() - 1) : 1;
NVTE_CHECK(input_last_dim % 128 == 0,
"NVFP4 multi-quantize requires inner dim to be multiple of 128.");
return res;
}
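// Illustrative call (a sketch that mirrors the two call sites below; the returned
// resources object must remain in scope while the configs are consumed):
//
//   std::vector<QuantizationConfigWrapper> cfg_row, cfg_col;
//   for (size_t i = 0; i < num_tensors; ++i) {
//     cfg_row.emplace_back(QuantizationConfigWrapper());
//     cfg_col.emplace_back(QuantizationConfigWrapper());
//   }
//   auto rng_res = setup_stochastic_rounding_rng_states_helper(
//       num_tensors, /*stochastic_rounding=*/true, /*with_bulk_generate_rng_states=*/true,
//       /*need_separate_rng_states=*/true, cfg_row, cfg_col);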
// CUDA stream
auto stream = at::cuda::getCurrentCUDAStream();
// Implements split-quantize NVFP4 with Row/Column-wise Hadamard Transform (RHT)
void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input,
const std::vector<TensorWrapper> &input_list,
std::vector<TensorWrapper> &output_list,
const std::vector<size_t> &split_sections,
const std::vector<NVFP4Quantizer *> &quantizers,
cudaStream_t stream) {
const size_t num_tensors = split_sections.size();
const auto &quantizer = *quantizers.front();
// Objects for TE C API
std::vector<NVTETensor> nvte_tensor_input_list;
std::vector<NVTETensor> nvte_tensor_output_list;
std::vector<QuantizationConfigWrapper> quant_config_list;
for (size_t i = 0; i < num_tensors; ++i) {
nvte_tensor_input_list.push_back(input_list[i].data());
nvte_tensor_output_list.push_back(output_list[i].data());
quant_config_list.emplace_back(QuantizationConfigWrapper());
}
// Stochastic rounding
// When both rowwise and columnwise quantization are used,
// we need separate RNG states for each to ensure they use different random numbers.
std::vector<TensorWrapper> te_rng_state_list;
std::vector<TensorWrapper> te_rng_state_columnwise_list;
std::vector<QuantizationConfigWrapper> quant_config_columnwise_list;
at::Tensor rng_states_tensor;
at::Tensor rng_states_columnwise_tensor;
const bool need_separate_columnwise_rng =
quantizer.stochastic_rounding && quantizer.with_rht && quantizer.columnwise_usage;
if (quantizer.stochastic_rounding) {
// TODO(zhongbo): remove the for loop of generating rng states with a single call
// with rng_elts_per_thread = 1024 * num_tensors
// Change to the bulk generate rng states api when grouped quantize is available
const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened
auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
rng_states_tensor = torch::empty({static_cast<int64_t>(2 * num_tensors)}, opts);
// Trigger the row-col fusion when all split-section sizes are 128-aligned, for maximum performance
bool all_aligned_token_dim =
std::all_of(split_sections.begin(), split_sections.end(),
[](size_t split_section) { return split_section % 128 == 0; });
// When the row-wise and col-wise passes cannot be fused, generate the RNG states twice
// so that the row-wise and col-wise passes consume different random numbers
bool need_separate_rng_states =
(!all_aligned_token_dim) && quantizer.rowwise_usage && quantizer.columnwise_usage;
// Allocate columnwise RNG resources when separate RNG is needed
if (need_separate_columnwise_rng) {
rng_states_columnwise_tensor = torch::empty({static_cast<int64_t>(2 * num_tensors)}, opts);
// Objects for TE C API
std::vector<QuantizationConfigWrapper> quant_config_list;
std::vector<QuantizationConfigWrapper> quant_config_list_colwise;
for (size_t i = 0; i < num_tensors; ++i) {
quant_config_columnwise_list.emplace_back(QuantizationConfigWrapper());
}
quant_config_list.emplace_back(QuantizationConfigWrapper());
quant_config_list_colwise.emplace_back(QuantizationConfigWrapper());
}
for (size_t i = 0; i < num_tensors; ++i) {
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
// Generate RNG state for rowwise quantization
at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
int64_t *rng_state_ptr = static_cast<int64_t *>(rng_states_tensor.data_ptr()) + i * 2;
philox_unpack(philox_args, rng_state_ptr);
te_rng_state_list.push_back(makeTransformerEngineTensor(
static_cast<void *>(rng_state_ptr), std::vector<size_t>{2}, DType::kInt64));
quant_config_list[i].set_rng_state(te_rng_state_list[i].data());
quant_config_list[i].set_stochastic_rounding(true);
// Generate separate RNG state for columnwise quantization
if (need_separate_columnwise_rng) {
at::PhiloxCudaState philox_args_columnwise = init_philox_state(gen, rng_elts_per_thread);
int64_t *rng_state_columnwise_ptr =
static_cast<int64_t *>(rng_states_columnwise_tensor.data_ptr()) + i * 2;
philox_unpack(philox_args_columnwise, rng_state_columnwise_ptr);
te_rng_state_columnwise_list.push_back(makeTransformerEngineTensor(
static_cast<void *>(rng_state_columnwise_ptr), std::vector<size_t>{2}, DType::kInt64));
quant_config_columnwise_list[i].set_rng_state(te_rng_state_columnwise_list[i].data());
quant_config_columnwise_list[i].set_stochastic_rounding(true);
// This is true because grouped kernels already exist for both row-wise and col-wise quantization with RHT
bool with_bulk_generate_rng_states = true;
// Stochastic rounding
bool need_stochastic_rounding = quantizer.stochastic_rounding;
auto stochastic_rng_state_resources = setup_stochastic_rounding_rng_states_helper(
num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states,
need_separate_rng_states, quant_config_list, quant_config_list_colwise);
// Enable NVFP4 kernels to use math operations that sacrifice
// accuracy for performance. These optimizations are experimental
// and inconsistently implemented.
const auto use_fast_math = transformer_engine::getenv<bool>("NVTE_USE_FAST_MATH");
if (use_fast_math) {
for (auto &config : quant_config_list) {
config.set_use_fast_math(true);
}
for (auto &config : quant_config_list_colwise) {
config.set_use_fast_math(true);
}
}
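// Note: with this wiring, fast math for the grouped NVFP4 kernels is opt-in at run time by
// setting NVTE_USE_FAST_MATH=1 in the environment before the process starts (assuming the
// getenv helper defaults to false when the variable is unset).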
// Perform multi-tensor quantization
if (quantizer.with_rht) { // Quantize row-wise data, RHT+quantize column-wise data
// Check that config is supported
NVTE_CHECK(input.dtype() == DType::kBFloat16, "RHT is only supported for bfloat16 input");
auto &quant_config_list_colwise_to_use =
need_separate_rng_states ? quant_config_list_colwise : quant_config_list;
// Compute amaxes
if (quantizer.with_post_rht_amax) {
// We need:
// 1. Rowwise amax = amax for input
// 2. Columnwise amax = amax for RHT(input.t)
NVTE_SCOPED_GIL_RELEASE({
nvte_group_hadamard_transform_amax(
input.data(), reinterpret_cast<NVTETensor *>(nvte_tensor_output_list.data()),
split_sections.data(), num_tensors, 0, quantizer.rht_matrix_random_sign_mask_t, stream);
});
} else {
// RHT is enabled, but amax is pre-RHT amax
NVTE_ERROR("NVFP4 split-quantize does not yet support pre-RHT amax");
......@@ -837,51 +871,64 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,
"RHT matrix is not available.");
auto rht_matrix_nvte = makeTransformerEngineTensor(quantizer.rht_matrix);
// Quantize tensors individually
NVTE_SCOPED_GIL_RELEASE({
for (size_t i = 0; i < num_tensors; i++) {
if (input_list[i].numel() == 0) {
continue; // Skip tensors with no elements
}
// Direct NVFP4 quantization for row-wise data
if (all_aligned_token_dim) {
// Call the fully fused grouped kernel that produces the row-wise quantization and the col-wise RHT quantization (transpose) in one pass
nvte_group_hadamard_transform_cast_fusion(
input.data(), reinterpret_cast<NVTETensor *>(nvte_tensor_output_list.data()),
rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0], stream);
} else {
// Otherwise, handle row-wise usage and column-wise usage with separate quantization passes
// Row-wise quantization using the grouped kernel
if (quantizer.rowwise_usage) {
auto out_rowwise_data = output_list[i].get_rowwise_data();
auto out_rowwise_scale_inv = output_list[i].get_rowwise_scale_inv();
auto out_rowwise_amax = output_list[i].get_amax();
TensorWrapper out_rowwise(output_list[i].scaling_mode());
out_rowwise.set_rowwise_data(out_rowwise_data.data_ptr,
static_cast<DType>(out_rowwise_data.dtype),
out_rowwise_data.shape);
out_rowwise.set_rowwise_scale_inv(out_rowwise_scale_inv.data_ptr,
static_cast<DType>(out_rowwise_scale_inv.dtype),
out_rowwise_scale_inv.shape);
out_rowwise.set_amax(out_rowwise_amax.data_ptr,
static_cast<DType>(out_rowwise_amax.dtype), out_rowwise_amax.shape);
nvte_quantize_v2(input_list[i].data(), out_rowwise.data(), quant_config_list[i], stream);
}
// RHT + NVFP4 quantize for column-wise data
std::vector<TensorWrapper> out_identity_list;
std::vector<NVTETensor> nvte_tensor_out_identity_list;
for (size_t i = 0; i < num_tensors; i++) {
bool is_empty_split = input_list[i].numel() == 0;
TensorWrapper out_identity(output_list[i].scaling_mode());
auto out_identity_data = output_list[i].get_rowwise_data();
auto out_identity_scale_inv = output_list[i].get_rowwise_scale_inv();
auto out_identity_amax = output_list[i].get_amax();
if (!is_empty_split) {
out_identity.set_rowwise_data(out_identity_data.data_ptr,
static_cast<DType>(out_identity_data.dtype),
out_identity_data.shape);
out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr,
static_cast<DType>(out_identity_scale_inv.dtype),
out_identity_scale_inv.shape);
out_identity.set_amax(out_identity_amax.data_ptr,
static_cast<DType>(out_identity_amax.dtype),
out_identity_amax.shape);
}
out_identity_list.emplace_back(std::move(out_identity));
nvte_tensor_out_identity_list.push_back(out_identity_list.back().data());
}
nvte_group_nvfp4_quantize_with_amax(input.data(), nvte_tensor_out_identity_list.data(),
split_sections.data(), num_tensors, quant_config_list[0],
stream);
}
// Columnwise RHT quantization fusion with grouped version
if (quantizer.columnwise_usage) {
// Get the output column-wise data, scale_inv, and amax
std::vector<TensorWrapper> out_transpose_list;
std::vector<NVTETensor> nvte_tensor_out_transpose_list;
for (size_t i = 0; i < num_tensors; i++) {
bool is_empty_split = input_list[i].numel() == 0;
auto out_columnwise_data = output_list[i].get_columnwise_data();
auto out_columnwise_scale_inv = output_list[i].get_columnwise_scale_inv();
auto out_columnwise_amax = output_list[i].get_columnwise_amax();
// Flatten column-wise data to 2D
// Create a wrapper that exposes the column-wise output as a row-wise output, since the input is already in the transposed layout
TensorWrapper out_transpose(output_list[i].scaling_mode());
if (!is_empty_split) {
auto colwise_data_shape = out_columnwise_data.shape;
std::vector<size_t> colwise_data_shape_2d;
colwise_data_shape_2d.push_back(colwise_data_shape.data[0]);
size_t last_dim = 1;
for (size_t i = 1; i < colwise_data_shape.ndim; ++i) {
last_dim *= colwise_data_shape.data[i];
for (size_t j = 1; j < colwise_data_shape.ndim; ++j) {
last_dim *= colwise_data_shape.data[j];
}
colwise_data_shape_2d.push_back(last_dim);
// Create a wrapper for the columnwise output, as the rowwise output.
// The reason is due to the input `rht_output_t` is already in the transposed layout.
// Thus, we only need a rowwise quantization to generate the columnwise output.
TensorWrapper out_transpose(output_list[i].scaling_mode());
out_transpose.set_rowwise_data(out_columnwise_data.data_ptr,
static_cast<DType>(out_columnwise_data.dtype),
colwise_data_shape_2d);
......@@ -891,19 +938,58 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,
out_transpose.set_amax(out_columnwise_amax.data_ptr,
static_cast<DType>(out_columnwise_amax.dtype),
out_columnwise_amax.shape);
}
out_transpose_list.emplace_back(std::move(out_transpose));
nvte_tensor_out_transpose_list.push_back(out_transpose_list.back().data());
}
nvte_group_hadamard_transform_cast_fusion_columnwise(
input.data(), reinterpret_cast<NVTETensor *>(nvte_tensor_out_transpose_list.data()),
rht_matrix_nvte.data(), split_sections.data(), num_tensors,
quant_config_list_colwise_to_use[0], stream);
}
}
}
// RHT + NVFP4 quantize kernel
// Use separate RNG state for columnwise to ensure different random numbers than rowwise
auto &columnwise_quant_config =
need_separate_columnwise_rng ? quant_config_columnwise_list[i] : quant_config_list[i];
nvte_hadamard_transform_cast_fusion_columnwise(input_list[i].data(), out_transpose.data(),
rht_matrix_nvte.data(),
columnwise_quant_config, stream);
void split_quantize_nvfp4_impl_helper(const TensorWrapper &input,
const std::vector<TensorWrapper> &input_list,
std::vector<TensorWrapper> &output_list,
const std::vector<size_t> &split_sections,
const std::vector<NVFP4Quantizer *> &quantizers,
cudaStream_t stream) {
const size_t num_tensors = input_list.size();
const auto &quantizer = *quantizers.front();
std::vector<NVTETensor> nvte_tensor_input_list;
std::vector<NVTETensor> nvte_tensor_output_list;
for (size_t i = 0; i < num_tensors; ++i) {
nvte_tensor_input_list.push_back(input_list[i].data());
nvte_tensor_output_list.push_back(output_list[i].data());
}
// Without RHT, the row-wise and col-wise quantization are fused in a single kernel,
// so separate RNG states for row-wise and col-wise are not needed
bool need_separate_rng_states = false;
// Objects for TE C API
std::vector<QuantizationConfigWrapper> quant_config_list;
for (size_t i = 0; i < num_tensors; ++i) {
quant_config_list.emplace_back(QuantizationConfigWrapper());
}
});
} else { // NVFP4 quantize
// TODO: bulk RNG-state generation is disabled only because the non-RHT path has no grouped
// kernels yet; once it does, all RNG states can be generated in a single call
bool with_bulk_generate_rng_states = false;
bool need_stochastic_rounding = quantizer.stochastic_rounding;
// Placeholder for col-wise RNG states, which are not needed in this case
std::vector<QuantizationConfigWrapper> dummy_quant_config_list_colwise;
auto stochastic_rng_state_resources = setup_stochastic_rounding_rng_states_helper(
num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states,
need_separate_rng_states, quant_config_list,
dummy_quant_config_list_colwise); // colwise rng states are not needed in this case
// We need:
// 1. Rowwise amax = amax for input
// 2. Columnwise amax = amax for input too
......@@ -919,16 +1005,13 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,
NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer");
output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1});
}
NVTE_SCOPED_GIL_RELEASE({
nvte_group_amax(input.data(), reinterpret_cast<NVTETensor *>(nvte_tensor_output_list.data()),
split_sections.data(), num_tensors, stream);
});
for (size_t i = 0; i < num_tensors; i++) {
output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector<size_t>{1});
}
// Quantize tensors individually
NVTE_SCOPED_GIL_RELEASE({
for (size_t i = 0; i < num_tensors; i++) {
// Skip this iteration if the input is empty
if (input_list[i].numel() == 0) {
......@@ -936,8 +1019,70 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,
}
nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream);
}
}
void split_quantize_nvfp4_impl(const TensorWrapper &input,
const std::vector<TensorWrapper> &input_list,
std::vector<TensorWrapper> &output_list,
const std::vector<size_t> &split_sections,
const std::vector<NVFP4Quantizer *> &quantizers) {
// Check tensor lists
const size_t num_tensors = split_sections.size();
NVTE_CHECK(input_list.size() == num_tensors, "Expected ", num_tensors, " input tensors, but got ",
input_list.size(), ".");
NVTE_CHECK(output_list.size() == num_tensors, "Expected ", num_tensors,
" output tensors, but got ", output_list.size(), ".");
NVTE_CHECK(quantizers.size() == num_tensors, "Expected ", num_tensors,
" NVFP4 quantizers, but got ", quantizers.size(), ".");
// Sanity check: all quantizers must use the same scaling mode
bool all_same_scaling_mode =
std::all_of(quantizers.begin(), quantizers.end(), [&](const NVFP4Quantizer *quantizer) {
return quantizer->get_scaling_mode() == quantizers.front()->get_scaling_mode();
});
NVTE_CHECK(all_same_scaling_mode, "All quantizers must have the same scaling mode");
// Trivial cases
if (num_tensors == 0) {
return;
}
if (input.numel() == 0) {
for (const auto &tensor : input_list) {
NVTE_CHECK(tensor.numel() == 0,
"Input tensor has zero elements but got split with non-zero elements");
}
return;
}
// Assume all quantizers have identical config
const auto &quantizer = *quantizers.front();
NVTE_CHECK(!quantizer.with_2d_quantization,
"NVFP4 split-quantize does not support 2D quantization");
NVTE_CHECK(!quantizer.with_amax_reduction,
"NVFP4 split-quantize does not support amax reduction");
// Check input tensor shape
const size_t input_last_dim = input.ndim() > 0 ? input.size(input.ndim() - 1) : 1;
NVTE_CHECK(input_last_dim % 128 == 0,
"NVFP4 multi-quantize requires inner dim to be multiple of 128.");
// CUDA stream
auto stream = at::cuda::getCurrentCUDAStream();
// Perform multi-tensor quantization
NVTE_SCOPED_GIL_RELEASE({
if (quantizer.with_rht) { // Quantize row-wise data, RHT+quantize column-wise data
// Check that config is supported
NVTE_CHECK(input.dtype() == DType::kBFloat16, "RHT is only supported for bfloat16 input");
// TODO: fuse the row-wise and col-wise passes into one kernel when it is ready
split_quantize_nvfp4_impl_with_rht_helper(input, input_list, output_list, split_sections,
quantizers, stream);
} else { // NVFP4 quantize
// TODO: fuse the row-wise and col-wise passes into one kernel when it is ready
split_quantize_nvfp4_impl_helper(input, input_list, output_list, split_sections, quantizers,
stream);
}
});
}
} // namespace
......
......@@ -1501,7 +1501,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
}
}
// Restriction for the RHT cast fusion kernel.
// Restriction for the RHT cast fusion kernel, since the RHT is computed with MMA hardware
bool eligible_for_rht_cast_fusion =
input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;
......
......@@ -120,7 +120,7 @@ def get_align_size_for_quantization(recipe: Recipe) -> int:
if recipe.mxfp8():
return 32
if recipe.nvfp4():
return 64
return 128
return 16
......