Unverified Commit eb8e792b authored by Zhongbo Zhu's avatar Zhongbo Zhu Committed by GitHub
Browse files

[PyTorch][NVFP4][MOE] NVFP4 Grouped Quantize with Hadamard Transform (#2411)



* rowwise colwise RHT group quant v1
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* remove local array RW
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* change wait_barrier
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* fast math options
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* use mult to replace div
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* format
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* bulk move random states
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* greptile
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* lint
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* revert to use divides
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* avoid fp32 bf16 round-trip in RHT cast fusion
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* trigger fastmath by toggle NVTE_RHT_CAST_FUSION_USE_FAST_MATH
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* integrate row col rht fusion, functional
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* numerics aligned
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* style
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* remove device sync
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* 128 padding
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* revert colwise rng state creation because of row-col fused kernel
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* fix CI, linter
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* refactor RS for generating two random values
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* Avoid invalid configs with templated kernel
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* fix acc pipeline init with 0 arrival count
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* restore rowwise-only mode
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* switch to dynamic atomic scheduler
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* Avoid instantiating group RHT+cast kernel without row-wise or col-wise output
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Include fast math option in quantization config
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix linter warnings and review nits
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Use TE license
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix bug where kernel is always launched on stream
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Restore BF16 intermediate downcast in fused RHT-cast kernels
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* fix numerical test of grouped kernel
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>

* Make sure row-wise and col-wise quantization use different RNG seeds
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

* Restore autoformatter
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarZhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 47902e96
...@@ -53,7 +53,7 @@ ncu -f -o ./benchmarks/linear/ncu_b200_numgemm_8_nvfp4_rht_amax \ ...@@ -53,7 +53,7 @@ ncu -f -o ./benchmarks/linear/ncu_b200_numgemm_8_nvfp4_rht_amax \
--set=full \ --set=full \
--kernel-name "GroupHadamardAmaxTmaKernel" \ --kernel-name "GroupHadamardAmaxTmaKernel" \
-s 5 -c 5 \ -s 5 -c 5 \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4 --profile python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4
""" """
...@@ -173,7 +173,9 @@ def benchmark_linear( ...@@ -173,7 +173,9 @@ def benchmark_linear(
return timing_ms return timing_ms
def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None): def run_benchmark_linear(
mkns, recipe_name, use_bias, num_gemms=4, m_splits_provided=None, fwd_only=False
):
data = [] data = []
assert not use_bias, "Bias is not supported for GroupedLinear benchmark" assert not use_bias, "Bias is not supported for GroupedLinear benchmark"
...@@ -182,14 +184,14 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None ...@@ -182,14 +184,14 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
device = "cuda" device = "cuda"
x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True) x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)] ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
assert m % num_gemms == 0 m_splits = [m // num_gemms] * num_gemms if m_splits_provided is None else m_splits_provided
m_splits = [m // num_gemms] * num_gemms if m_splits is None else m_splits
# Bias is not supported for GroupedLinear benchmark # Bias is not supported for GroupedLinear benchmark
bias = None bias = None
# Run the benchmark # Run the benchmark
print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}") print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}")
print(f"m_splits: {m_splits}") print(f"m_splits: {m_splits}")
print(f"fwd_only: {fwd_only}")
grouped_fwd_bwd_timing_ms = benchmark_linear( grouped_fwd_bwd_timing_ms = benchmark_linear(
x, x,
...@@ -197,7 +199,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None ...@@ -197,7 +199,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
m_splits, m_splits,
bias, bias,
recipe_name, recipe_name,
mode="fwd_bwd", mode="fwd_only" if fwd_only else "fwd_bwd",
num_gemms=num_gemms, num_gemms=num_gemms,
) )
...@@ -213,6 +215,8 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None ...@@ -213,6 +215,8 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
] ]
) )
timing_notation = "grouped_fwd_time_ms" if fwd_only else "grouped_fwd_bwd_time_ms"
df = pd.DataFrame( df = pd.DataFrame(
data=data, data=data,
columns=[ columns=[
...@@ -221,7 +225,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None ...@@ -221,7 +225,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
"n", "n",
"recipe", "recipe",
"num_gemms", "num_gemms",
"grouped_fwd_bwd_time_ms", timing_notation,
], ],
) )
...@@ -234,7 +238,7 @@ if __name__ == "__main__": ...@@ -234,7 +238,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--profile", action="store_true", help="Enable profiling mode") parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
parser.add_argument( parser.add_argument(
"--output_dir", "--output-dir",
type=str, type=str,
default="benchmark_output/", default="benchmark_output/",
help="output path for report", help="output path for report",
...@@ -266,6 +270,12 @@ if __name__ == "__main__": ...@@ -266,6 +270,12 @@ if __name__ == "__main__":
default=2048, default=2048,
help="Output dimension to use, default is 2048", help="Output dimension to use, default is 2048",
) )
parser.add_argument(
"--fwd-only",
action="store_true",
default=False,
help="Run forward pass only, default is both forward and backward passes",
)
args = parser.parse_args() args = parser.parse_args()
jagged_input_splits = None jagged_input_splits = None
...@@ -297,7 +307,7 @@ if __name__ == "__main__": ...@@ -297,7 +307,7 @@ if __name__ == "__main__":
if jagged_input_splits is not None: if jagged_input_splits is not None:
num_gemms_list = [len(jagged_input_splits)] num_gemms_list = [len(jagged_input_splits)]
token_dim_list = [65536] token_dim_list = [16384, 32768, 65536, 98304]
hidden_dim_list = [7168] hidden_dim_list = [7168]
output_dim_list = [2048] output_dim_list = [2048]
...@@ -371,7 +381,8 @@ if __name__ == "__main__": ...@@ -371,7 +381,8 @@ if __name__ == "__main__":
recipe_name, recipe_name,
use_bias, use_bias,
num_gemms=num_gemms, num_gemms=num_gemms,
m_splits=jagged_input_splits, m_splits_provided=jagged_input_splits,
fwd_only=args.fwd_only,
) )
df_linears = pd.concat([df_linears, df]) df_linears = pd.concat([df_linears, df])
......
...@@ -198,7 +198,7 @@ def check_group_quantization_nvfp4_versus_reference( ...@@ -198,7 +198,7 @@ def check_group_quantization_nvfp4_versus_reference(
for i in range(len(x_qx)): for i in range(len(x_qx)):
if split_sections[i] == 0: if split_sections[i] == 0:
# then just assert the same same and dtype because the buffer won't be zero out # then just assert the same shape and dtype because the buffer won't be zero out
assert_same_shape_and_dtype(x_amax_rowwise[i], x_amax_rowwise_ref[i]) assert_same_shape_and_dtype(x_amax_rowwise[i], x_amax_rowwise_ref[i])
assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i]) assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i])
assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i]) assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i])
...@@ -221,7 +221,7 @@ def check_group_quantization_nvfp4_versus_reference( ...@@ -221,7 +221,7 @@ def check_group_quantization_nvfp4_versus_reference(
# assert with zero tolerance # assert with zero tolerance
for i in range(len(x_qx_t)): for i in range(len(x_qx_t)):
if split_sections[i] == 0: if split_sections[i] == 0:
# then just assert the same same and dtype because the buffer won't be zero out # then just assert the same shape and dtype because the buffer won't be zero out
assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i]) assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i])
assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i]) assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i])
assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i]) assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i])
...@@ -247,6 +247,7 @@ def check_group_quantization_nvfp4_versus_reference( ...@@ -247,6 +247,7 @@ def check_group_quantization_nvfp4_versus_reference(
(1024, 256), (1024, 256),
# larger sizes # larger sizes
(8192, 1024), (8192, 1024),
(16384, 8192),
(16384, 16384), (16384, 16384),
], ],
) )
......
...@@ -174,6 +174,8 @@ list(APPEND transformer_engine_cuda_arch_specific_sources ...@@ -174,6 +174,8 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
hadamard_transform/group_hadamard_transform.cu hadamard_transform/group_hadamard_transform.cu
hadamard_transform/hadamard_transform.cu hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu hadamard_transform/hadamard_transform_cast_fusion.cu
hadamard_transform/group_hadamard_transform_cast_fusion.cu
hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
multi_tensor/compute_scale.cu multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu recipe/mxfp8_scaling.cu
transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_square_blockwise.cu
......
...@@ -100,3 +100,18 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, ...@@ -100,3 +100,18 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s))); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
} }
} }
// Group quantize assumes contiguous inputs and outputs in memory allocation
// TODO (zhongbo): find a better way to make it a more generalized API
void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs,
const size_t *split_sections, const size_t num_tensors,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_nvfp4_quantize_with_amax);
using namespace transformer_engine;
constexpr bool IS_ACT = false;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, outputs, split_sections,
num_tensors, quant_config, stream);
}
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "../core/common.cuh" #include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh" #include "../fp8/quantize_fp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh"
#include "../nvfp4/group_quantize_transpose_nvfp4.cuh"
#include "../nvfp4/quantize_nvfp4.cuh" #include "../nvfp4/quantize_nvfp4.cuh"
#include "../nvfp4/quantize_transpose_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh"
...@@ -320,6 +321,70 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens ...@@ -320,6 +321,70 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
} }
} }
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs,
const size_t *split_sections, const size_t num_tensors,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
using namespace detail;
const Tensor *input_tensor = convertNVTETensorCheck(input);
std::vector<Tensor *> output_tensors;
for (size_t i = 0; i < num_tensors; ++i) {
output_tensors.push_back(convertNVTETensorCheck(outputs[i]));
}
// Quantization config
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// Check for unsupported options
if (quant_config_cpp.stochastic_rounding) {
NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Stochastic rounding is only supported for NVFP4 quantization.");
}
// Take the scaling mode of the first output tensor
auto scaling_mode = output_tensors[0]->scaling_mode;
// Dispatch to quantization kernel depending on data format
switch (scaling_mode) {
case NVTE_NVFP4_1D_SCALING: {
NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING");
// Check tensors
CheckNoopTensor(*noop_tensor, "cast_noop");
CheckInputTensor(*input_tensor, "input");
// Skip checking output tensor list
// output list here is allowed to have empty tensor
// Choose kernel
int32_t rows = input_tensor->flat_first_dim();
int32_t cols = input_tensor->flat_last_dim();
auto dtype = input_tensor->dtype();
NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization,
"2D quantization is not supported for group quantize.");
// Launch NVFP4 group quantize kernel
nvfp4::group_quantize_transpose</*use_2d_quantization*/ false>(
*input_tensor, noop_tensor, output_tensors, split_sections, num_tensors,
&quant_config_cpp, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
}
}
} // namespace dispatch } // namespace dispatch
} // namespace transformer_engine } // namespace transformer_engine
......
...@@ -394,6 +394,7 @@ struct QuantizationConfig { ...@@ -394,6 +394,7 @@ struct QuantizationConfig {
NVTETensor rng_state = nullptr; NVTETensor rng_state = nullptr;
bool nvfp4_2d_quantization = false; bool nvfp4_2d_quantization = false;
bool stochastic_rounding = false; bool stochastic_rounding = false;
bool use_fast_math = false;
static constexpr size_t attr_sizes[] = { static constexpr size_t attr_sizes[] = {
sizeof(bool), // force_pow_2_scales sizeof(bool), // force_pow_2_scales
...@@ -402,7 +403,8 @@ struct QuantizationConfig { ...@@ -402,7 +403,8 @@ struct QuantizationConfig {
sizeof(Float8BlockScaleTensorFormat), // float8_block_scale_tensor_format sizeof(Float8BlockScaleTensorFormat), // float8_block_scale_tensor_format
sizeof(NVTETensor), // rng_seed and offset sizeof(NVTETensor), // rng_seed and offset
sizeof(bool), // nvfp4_2d_quantization sizeof(bool), // nvfp4_2d_quantization
sizeof(bool) // stochastic_rounding sizeof(bool), // stochastic_rounding
sizeof(bool) // use_fast_math
}; };
}; };
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_CUSTOMIZED_PIPELINE_CUH_
#define TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_CUSTOMIZED_PIPELINE_CUH_
#include "cutlass/pipeline/sm100_pipeline.hpp"
namespace cutlass {
using namespace cute;
namespace detail {
// Producer-consumer pipeline implementation
// for UMMA producer. In this case, UMMA barrier arrives are used
// by producer_commit. Use case, accumulator generation as
// the result of MMA instructions.
//
// The class wraps cutlass's PipelineTmaAsync and forwards most operations to
// it; the customization is (a) barrier initialization with a producer arrival
// count of 1 and a multicast-aware consumer arrival count, and (b) an
// additional umma_consumer_release() that signals the empty barrier via UMMA
// arrive instructions instead of a plain thread arrive.
template <int Stages_, class ClusterShape = Shape<int, int, _1>,
          class AtomThrShape_MNK_ = Shape<_1, _1, _1> >
class CustomizedPipelineTmaUmmaAsync {
 public:
  static constexpr uint32_t Stages = Stages_;
  using AtomThrShape_MNK = AtomThrShape_MNK_;

 private:
  // Underlying TMA pipeline; everything except consumer release and barrier
  // init is delegated to this implementation.
  using Impl = PipelineTmaAsync<Stages>;

 public:
  using FullBarrier = typename Impl::FullBarrier;
  using EmptyBarrier = typename Impl::EmptyBarrier;
  using ProducerBarrierType = typename Impl::ProducerBarrierType;
  using ConsumerBarrierType = typename Impl::ConsumerBarrierType;
  using PipelineState = typename Impl::PipelineState;
  using SharedStorage = typename Impl::SharedStorage;
  using ThreadCategory = typename Impl::ThreadCategory;
  using Params = typename Impl::Params;
  // Re-export the enclosing namespace's McastDirection as a member alias
  // (same pattern as the CUTLASS SM100 pipeline classes).
  using McastDirection = McastDirection;

  // Helper function to initialize barriers.
  // Only the warp designated by params.initializing_warp performs the init;
  // all threads then fence so the initialized barriers are visible.
  static CUTLASS_DEVICE void init_barriers(SharedStorage& storage, Params params,
                                           ClusterShape cluster_shape) {
    int warp_idx = canonical_warp_idx_sync();
    if (warp_idx == params.initializing_warp) {
      // Barrier FULL and EMPTY init
      constexpr int producer_arv_cnt = 1;
      auto atom_thr_shape = AtomThrShape_MNK{};
      uint32_t multicast_consumer_arrival_count = params.num_consumers;  // If cluster_size is 1
      if (cute::size(cluster_shape) > 1) {
        // With multicast, consumers from one row and one column of the cluster
        // (normalized by the MMA atom thread shape, minus the shared corner
        // block) arrive on each empty barrier.
        multicast_consumer_arrival_count =
            ((cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) +
             (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1) *
            params.num_consumers;
      }
      CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 &&
                     "Multicast consumer arrival count must be non-zero");
      CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero");
      cutlass::arch::detail::initialize_barrier_array_pair_aligned<
          decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
          storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt,
          multicast_consumer_arrival_count);
    }
    cutlass::arch::fence_barrier_init();
  }

  // Compute the consumer's multicast mask covering both the row and the
  // column of this block within the cluster.
  CUTLASS_DEVICE
  void init_masks(ClusterShape cluster_shape,
                  dim3 block_id_in_cluster = cute::block_id_in_cluster()) {
    // Calculate consumer mask
    if (params_.role == ThreadCategory::Consumer) {
      block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRowCol>(
          cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
    }
  }

  // Overload selecting a single multicast direction (row or column) at runtime.
  CUTLASS_DEVICE
  void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) {
    // Calculate consumer mask
    dim3 block_id_in_cluster = cute::block_id_in_cluster();
    if (mcast_direction == McastDirection::kRow) {
      block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRow>(
          cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
    } else {
      block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kCol>(
          cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
    }
  }

  // Constructor by default initializes barriers and calculates masks.
  // These operations can be explicitly deferred by specifying InitBarriers and InitMasks.
  // If deferred, user code needs to guarantee init_masks and/or init_barriers is/are called.
  // Note: impl_ is always constructed with InitBarriers = false_type so that
  // this class's own init_barriers (with the custom arrival counts) is used
  // instead of the base implementation's.
  template <typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>
  CUTLASS_DEVICE CustomizedPipelineTmaUmmaAsync(SharedStorage& storage, Params params,
                                                ClusterShape cluster_shape, InitBarriers = {},
                                                InitMasks = {})
      : impl_(storage, params, cluster_shape, cute::false_type{}, InitMasks{}),
        params_(params),
        empty_barrier_ptr_(&storage.empty_barrier_[0]),
        full_barrier_ptr_(&storage.full_barrier_[0]) {
    static_assert(cute::is_same_v<InitBarriers, cute::true_type> ||
                  cute::is_same_v<InitBarriers, cute::false_type>);
    if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {
      init_barriers(storage, params_, cluster_shape);
    }
    static_assert(cute::is_same_v<InitMasks, cute::true_type> ||
                  cute::is_same_v<InitMasks, cute::false_type>);
    if constexpr (cute::is_same_v<InitMasks, cute::true_type>) {
      init_masks(cluster_shape);
    }
  }

  ////////////////////
  // Producer APIs
  ////////////////////
  // Four member functions are always used in pairs:
  //
  // * producer_try_acquire and producer_acquire, and
  // * consumer_try_wait and consumer_wait.
  //
  // The two functions with "try" in their names are called "try" functions,
  // and the other two are conceptually "finalize" functions.
  // The "try" function in each pair starts the process of waiting on the barrier to flip.
  // It opportunistically waits for an implementation-dependent timeout.
  // Whether or not the barrier has flipped yet, the try function will return a token.
  // If the token indicates that the barrier has not flipped,
  // then the token must be passed into the corresponding "finalize" function.
  // The finalize function will then block until the barrier has flipped.
  // If the token indicates that the barrier _has_ flipped,
  // then it is still correct to pass it into the finalize function.
  // The finalize function will return immediately in that case.

  CUTLASS_DEVICE
  ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) {
    return impl_.producer_try_acquire(state, skip_wait);
  }

  CUTLASS_DEVICE
  void producer_acquire(PipelineState state,
                        ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
    impl_.producer_acquire(state, barrier_token);
  }

  CUTLASS_DEVICE
  void producer_expect_transaction(PipelineState state, uint32_t transaction_bytes) {
    impl_.producer_expect_transaction(state, transaction_bytes);
  }

  // NOP for TMA based mainloop
  CUTLASS_DEVICE
  void producer_commit(PipelineState state, uint32_t bytes) { impl_.producer_commit(state, bytes); }

  // Prevents early exit of producer blocks in Cluster.
  // This should be called once before kernel exits.
  CUTLASS_DEVICE
  void producer_tail(PipelineState state) { impl_.producer_tail(state); }

  CUTLASS_DEVICE
  ProducerBarrierType* producer_get_barrier(PipelineState state) {
    return impl_.producer_get_barrier(state);
  }

  ////////////////////
  // Consumer APIs
  ////////////////////
  CUTLASS_DEVICE
  ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) {
    return impl_.consumer_try_wait(state, skip_wait);
  }

  CUTLASS_DEVICE
  void consumer_wait(PipelineState state,
                     ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) {
    impl_.consumer_wait(state, barrier_token);
  }

  // Release the stage indicated by `state` using a UMMA arrive (see the
  // private overload below); the plain consumer_release delegates to the
  // base implementation's thread arrive instead.
  CUTLASS_DEVICE
  void umma_consumer_release(PipelineState state) { umma_consumer_release(state.index(), false); }

  CUTLASS_DEVICE
  void consumer_release(PipelineState state) { impl_.consumer_release(state); }

 private:
  Impl impl_;
  Params params_;
  EmptyBarrier* empty_barrier_ptr_;
  FullBarrier* full_barrier_ptr_;
  // Multicast mask of cluster blocks that must observe this consumer's arrive.
  uint16_t block_id_mask_ = 0;
  // True when the MMA atom spans more than one SM (e.g. 2x1SM MMA).
  static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1;

  // Consumer signalling Producer of completion
  // Ensures all blocks in the Same Row and Column get notified.
  CUTLASS_DEVICE
  void umma_consumer_release(uint32_t stage, uint32_t skip) {
    detail::pipeline_check_is_consumer(params_.role);
    uint64_t* smem_ptr = reinterpret_cast<uint64_t*>(&empty_barrier_ptr_[stage]);
    // {$nv-release-never begin}
    // TODO: Needs to be updated once Blackwell specialized pipeline is implemented.
    // XMMA style bar_peek will be tested. We will need to revisit skip interface and
    // what skip means when we have bar_peek functionality.
    // A separate MR will implement MMA_2x1SM specialized pipeline.
    // {$nv-release-never end}
    if constexpr (is_2sm_mma) {  // Mma cluster shape is 2x1
      if (!skip) {
        cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_);
      }
    } else {
      if (!skip) {
        // Single-block cluster can use a plain (non-multicast) UMMA arrive.
        if constexpr (cute::is_static_v<ClusterShape> && size(ClusterShape{}) == 1) {
          cutlass::arch::umma_arrive(smem_ptr);
        } else {
          cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_);
        }
      }
    }
  }
};
} // namespace detail
} // namespace cutlass
#endif // TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_CUSTOMIZED_PIPELINE_CUH_
...@@ -459,7 +459,7 @@ void group_hadamard_transform_amax(const Tensor& input_, std::vector<Tensor*>& o ...@@ -459,7 +459,7 @@ void group_hadamard_transform_amax(const Tensor& input_, std::vector<Tensor*>& o
} }
// Multi zero out multiple amaxes if needed // Multi zero out multiple amaxes if needed
// Curretly don't support multi-launch when num_tensors is larger than kMaxTensorsPerKernel // Currently don't support multi-launch when num_tensors is larger than kMaxTensorsPerKernel
// let the number of threads equal to number of tensors, use 1 block, kMaxTensorsPerKernel threads per block // let the number of threads equal to number of tensors, use 1 block, kMaxTensorsPerKernel threads per block
dim3 block_setup_amax(kMaxTensorsPerKernel); dim3 block_setup_amax(kMaxTensorsPerKernel);
dim3 grid_setup_amax(1); dim3 grid_setup_amax(1);
......
...@@ -29,7 +29,6 @@ ...@@ -29,7 +29,6 @@
#include "cutlass/pipeline/pipeline.hpp" #include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/util/GPU_Clock.hpp" #include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/command_line.h" #include "cutlass/util/command_line.h"
#include "cutlass/util/helper_cuda.hpp"
#include "cutlass/util/print_error.hpp" #include "cutlass/util/print_error.hpp"
// clang-format off // clang-format off
...@@ -129,7 +128,8 @@ template <class MShape, class NShape, class KShape, class ClusterTileShape, ...@@ -129,7 +128,8 @@ template <class MShape, class NShape, class KShape, class ClusterTileShape,
class TC, class CStride, class CSmemLayout, class TC, class CStride, class CSmemLayout,
class TSFC, class TSFC,
class TiledMMA, class TiledMMA,
bool kEnableStochasticRounding = false> bool kEnableStochasticRounding = false,
bool kUseFastMath = false>
__global__ static __global__ static
void void
rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
...@@ -426,7 +426,13 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, ...@@ -426,7 +426,13 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val);
const float global_decode_scale = 1.0f / global_encode_scale; const float global_decode_scale = 1.0f / global_encode_scale;
auto sfd_converter = cutlass::NumericConverter<TSFC, float>{};
// Scaling factor for fast math path
float global_encode_scale_multiplier = 1.0f;
if constexpr (kUseFastMath) {
static constexpr float fp4_max_inv = 1.0f / fp4_max;
global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
}
do { do {
for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) { for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) {
...@@ -469,10 +475,13 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, ...@@ -469,10 +475,13 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
++accumulator_pipe_consumer_state; ++accumulator_pipe_consumer_state;
// Cast data from FP32 to BF16 to FP32. if constexpr (!kUseFastMath) {
auto convert_accum_to_bf16 = cutlass::NumericArrayConverter<cutlass::bfloat16_t, ElementAccumulator, FragmentSize>{}; // Downcast to BF16 for bit-wise compatibility with unfused
auto convert_bf16_to_accum = cutlass::NumericArrayConverter<ElementAccumulator, cutlass::bfloat16_t, FragmentSize>{}; // kernels
tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); auto convert_accum_to_bf16 = cutlass::NumericArrayConverter<cutlass::bfloat16_t, ElementAccumulator, FragmentSize>{};
auto convert_bf16_to_accum = cutlass::NumericArrayConverter<ElementAccumulator, cutlass::bfloat16_t, FragmentSize>{};
tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{})));
}
auto compute_frgs = reinterpret_cast<cutlass::Array< ElementAccumulator, VectorSize> *>(tTR_rAcc_frag.data()); auto compute_frgs = reinterpret_cast<cutlass::Array< ElementAccumulator, VectorSize> *>(tTR_rAcc_frag.data());
auto output_frgs = reinterpret_cast<cutlass::Array< TC, VectorSize> *>(tDrC_frag.data()); auto output_frgs = reinterpret_cast<cutlass::Array< TC, VectorSize> *>(tDrC_frag.data());
...@@ -481,14 +490,27 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, ...@@ -481,14 +490,27 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]);
} }
pvscales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(vec_maxs, fp4_max); if constexpr (kUseFastMath) {
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(pvscales, global_encode_scale); // Fast math: multiply with precomputed reciprocal
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(vec_maxs, global_encode_scale_multiplier);
} else {
// Accurate math: perform division
pvscales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(vec_maxs, fp4_max);
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(pvscales, global_encode_scale);
}
auto pvscales_cvted = cutlass::NumericArrayConverter<TSFC, ElementAccumulator, NumVecs>{}(pvscales); auto pvscales_cvted = cutlass::NumericArrayConverter<TSFC, ElementAccumulator, NumVecs>{}(pvscales);
tC_rRowSFD_frg(_0{}) = pvscales_cvted; tC_rRowSFD_frg(_0{}) = pvscales_cvted;
auto qpvscale_ups = cutlass::NumericArrayConverter<ElementAccumulator, TSFC, NumVecs>{}(tC_rRowSFD_frg(_0{})); auto qpvscale_ups = cutlass::NumericArrayConverter<ElementAccumulator, TSFC, NumVecs>{}(tC_rRowSFD_frg(_0{}));
auto qpvscale_scaled = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(qpvscale_ups, global_decode_scale); auto qpvscale_scaled = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(qpvscale_ups, global_decode_scale);
auto acc_scales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(1.0, qpvscale_scaled); cutlass::Array<ElementAccumulator, NumVecs> acc_scales;
if constexpr (kUseFastMath) {
// Fast math: compute approximate reciprocal
acc_scales = cutlass::reciprocal_approximate_ftz<decltype(qpvscale_scaled)>{}(qpvscale_scaled);
} else {
// Accurate math: compute reciprocal with division
acc_scales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(1.0, qpvscale_scaled);
}
// Initialize RNG for tile // Initialize RNG for tile
const size_t rng_sequence const size_t rng_sequence
...@@ -532,7 +554,7 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, ...@@ -532,7 +554,7 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
// B: 16 x 16: row-major // B: 16 x 16: row-major
// C: m x n: row-major // C: m x n: row-major
// SFC: m x (n/16): row-major // SFC: m x (n/16): row-major
template <typename TA, typename TB, typename TC, typename TSFC, bool kEnableStochasticRounding = false> template <typename TA, typename TB, typename TC, typename TSFC, bool kEnableStochasticRounding = false, bool kUseFastMath = false>
void void
rht_gemm_ntt_w_sfc(int m, int n, rht_gemm_ntt_w_sfc(int m, int n,
TA const* A, TA const* A,
...@@ -644,16 +666,15 @@ rht_gemm_ntt_w_sfc(int m, int n, ...@@ -644,16 +666,15 @@ rht_gemm_ntt_w_sfc(int m, int n,
TC, decltype(dC), decltype(sC), TC, decltype(dC), decltype(sC),
TSFC, TSFC,
decltype(mma), decltype(mma),
kEnableStochasticRounding>; kEnableStochasticRounding,
kUseFastMath>;
bool status = cudaFuncSetAttribute(*kernel_ptr, NVTE_CHECK_CUDA(
cudaFuncAttributeMaxDynamicSharedMemorySize, cudaFuncSetAttribute(*kernel_ptr,
smem_size); cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size)
);
if (status != cudaSuccess) {
std::cerr << "Error: Failed to set Shared Memory size." << std::endl;
return;
}
(*kernel_ptr) (*kernel_ptr)
<<< dimGrid, dimBlock, smem_size, stream >>> <<< dimGrid, dimBlock, smem_size, stream >>>
(M, N, k_tile_size, cga_tile_shape, (M, N, k_tile_size, cga_tile_shape,
...@@ -663,11 +684,12 @@ rht_gemm_ntt_w_sfc(int m, int n, ...@@ -663,11 +684,12 @@ rht_gemm_ntt_w_sfc(int m, int n,
SFC, SFC,
mma, global_amax, mma, global_amax,
rng_state); rng_state);
NVTE_CHECK_CUDA(cudaGetLastError());
} }
// this function is used to wrap the rht_gemm_ntt_w_sfc function // this function is used to wrap the rht_gemm_ntt_w_sfc function
//to transpose the input tensor A //to transpose the input tensor A
template <typename TA, typename TB, typename TC, typename TSFC, bool kEnableStochasticRounding = false> template <typename TA, typename TB, typename TC, typename TSFC, bool kEnableStochasticRounding = false, bool kUseFastMath = false>
void void
rht_gemm_ttt_wrapper(int m, int n, rht_gemm_ttt_wrapper(int m, int n,
TA const* A, TA const* A,
...@@ -690,7 +712,7 @@ rht_gemm_ttt_wrapper(int m, int n, ...@@ -690,7 +712,7 @@ rht_gemm_ttt_wrapper(int m, int n,
// B: 16 x 16: row-major // B: 16 x 16: row-major
// C: n x m: row-major // C: n x m: row-major
// SFC: n x (m/16): row-major // SFC: n x (m/16): row-major
rht_gemm_ntt_w_sfc<TA, TB, TC, TSFC, kEnableStochasticRounding>( rht_gemm_ntt_w_sfc<TA, TB, TC, TSFC, kEnableStochasticRounding, kUseFastMath>(
n, m, n, m,
A, B, C, A, B, C,
SFC, global_amax, SFC, global_amax,
...@@ -800,20 +822,23 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out ...@@ -800,20 +822,23 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
} else if (m < 1024 || n < 1024) { } else if (m < 1024 || n < 1024) {
k_tile_size = 512; k_tile_size = 512;
} }
TRANSFORMER_ENGINE_SWITCH_CONDITION( TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, kUseStochasticRounding, use_stochastic_rounding, kUseStochasticRounding,
detail::rht_gemm_ttt_wrapper<TA, TB, TC, TSFC, kUseStochasticRounding>( TRANSFORMER_ENGINE_SWITCH_CONDITION(
/*m=*/m, quant_config.use_fast_math, kUseFastMath,
/*n=*/n, detail::rht_gemm_ttt_wrapper<TA, TB, TC, TSFC, kUseStochasticRounding, kUseFastMath>(
/*A=*/reinterpret_cast<TA const *>(input.dptr), /*m=*/m,
/*B=*/reinterpret_cast<TB const *>(hadamard_matrix.dptr), /*n=*/n,
/*C=*/reinterpret_cast<TC *>(output_t.dptr), /*A=*/reinterpret_cast<TA const *>(input.dptr),
/*SFC=*/reinterpret_cast<TSFC *>(scale_inv_t.dptr), /*B=*/reinterpret_cast<TB const *>(hadamard_matrix.dptr),
/*global_amax=*/reinterpret_cast<float const *>(global_amax.dptr), /*C=*/reinterpret_cast<TC *>(output_t.dptr),
/*rng_state=*/rng_state, /*SFC=*/reinterpret_cast<TSFC *>(scale_inv_t.dptr),
/*sm_count=*/sm_count, /*global_amax=*/reinterpret_cast<float const *>(global_amax.dptr),
/*stream=*/stream, /*rng_state=*/rng_state,
/*k_tile_size=*/k_tile_size);); /*sm_count=*/sm_count,
/*stream=*/stream,
/*k_tile_size=*/k_tile_size);););
} }
} // namespace transformer_engine } // namespace transformer_engine
......
...@@ -270,6 +270,20 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, ...@@ -270,6 +270,20 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
const NVTEQuantizationConfig quant_config, const size_t num_tensors, const NVTEQuantizationConfig quant_config, const size_t num_tensors,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Casts grouped input tensor to quantized output tensors.
*
* \param[in] input Input tensor to be cast.
* \param[in,out] outputs Output quantized tensors.
* \param[in] split_sections Split sections of the input tensor.
* \param[in] num_tensors Number of output tensors.
* \param[in] quant_config (Optional) Quantization configurations.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs,
const size_t *split_sections, size_t num_tensors,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
...@@ -86,6 +86,43 @@ void nvte_group_hadamard_transform_amax(const NVTETensor input, NVTETensor* outp ...@@ -86,6 +86,43 @@ void nvte_group_hadamard_transform_amax(const NVTETensor input, NVTETensor* outp
int random_sign_mask, int random_sign_mask_t, int random_sign_mask, int random_sign_mask_t,
cudaStream_t stream); cudaStream_t stream);
/*!
* \brief Perform the grouped-tensor columnwise Hadamard transform cast fusion operation.
*
* This function is experimental and the API is not stable. Group_ prefix means contiguous input concatenated
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] outputs Array of output tensors.
* \param[in] hadamard_matrix Hadamard matrix to use for transformation.
* \param[in] split_sections Array specifying splits in dimension 0 for each output tensor.
* \param[in] num_tensors Number of output tensors, must be > 0.
* \param[in] quant_config Quantization configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_hadamard_transform_cast_fusion_columnwise(
const NVTETensor input, NVTETensor* outputs, const NVTETensor hadamard_matrix,
const size_t* split_sections, size_t num_tensors, const NVTEQuantizationConfig quant_config,
cudaStream_t stream);
/*!
* \brief Perform the grouped-tensor row quantize (without Hadamard) and columnwise Hadamard transform cast fusion operation.
*
* This function is experimental and the API is not stable. Group_ prefix means contiguous input concatenated
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] outputs Array of output tensors.
* \param[in] hadamard_matrix Hadamard matrix to use for transformation.
* \param[in] split_sections Array specifying splits in dimension 0 for each output tensor.
* \param[in] num_tensors Number of output tensors, must be > 0.
* \param[in] quant_config Quantization configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor* outputs,
const NVTETensor hadamard_matrix,
const size_t* split_sections, size_t num_tensors,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
...@@ -337,6 +337,12 @@ enum NVTEQuantizationConfigAttribute { ...@@ -337,6 +337,12 @@ enum NVTEQuantizationConfigAttribute {
kNVTEQuantizationConfigNVFP42DQuantization = 5, kNVTEQuantizationConfigNVFP42DQuantization = 5,
/*! Whether to enable stochastic rounding */ /*! Whether to enable stochastic rounding */
kNVTEQuantizationConfigStochasticRounding = 6, kNVTEQuantizationConfigStochasticRounding = 6,
/*! Whether to enable fast math operations with reduced accuracy.
*
* Optimizations are kernel-specific and they may be applied
* inconsistently between kernels.
*/
kNVTEQuantizationConfigUseFastMath = 7,
kNVTEQuantizationConfigNumAttributes kNVTEQuantizationConfigNumAttributes
}; };
...@@ -997,6 +1003,12 @@ class QuantizationConfigWrapper { ...@@ -997,6 +1003,12 @@ class QuantizationConfigWrapper {
&stochastic_rounding, sizeof(bool)); &stochastic_rounding, sizeof(bool));
} }
/*! \brief Set whether to enable fast math operations */
void set_use_fast_math(bool use_fast_math) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigUseFastMath,
&use_fast_math, sizeof(bool));
}
private: private:
/*! \brief Wrapped NVTEQuantizationConfig. */ /*! \brief Wrapped NVTEQuantizationConfig. */
NVTEQuantizationConfig config_ = nullptr; NVTEQuantizationConfig config_ = nullptr;
......
...@@ -857,9 +857,10 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, ...@@ -857,9 +857,10 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
// Write attribute size // Write attribute size
NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes, NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
"Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")"); "Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr]; const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
*size_written = attr_size; if (size_written != nullptr) {
*size_written = attr_size;
}
// Return immediately if buffer is not provided // Return immediately if buffer is not provided
if (buf == nullptr) { if (buf == nullptr) {
...@@ -889,6 +890,18 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, ...@@ -889,6 +890,18 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat: case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size); std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size);
break; break;
case kNVTEQuantizationConfigRNGState:
std::memcpy(buf, &config_.rng_state, attr_size);
break;
case kNVTEQuantizationConfigNVFP42DQuantization:
std::memcpy(buf, &config_.nvfp4_2d_quantization, attr_size);
break;
case kNVTEQuantizationConfigStochasticRounding:
std::memcpy(buf, &config_.stochastic_rounding, attr_size);
break;
case kNVTEQuantizationConfigUseFastMath:
std::memcpy(buf, &config_.use_fast_math, attr_size);
break;
default: default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")"); NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
} }
...@@ -933,6 +946,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, ...@@ -933,6 +946,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigStochasticRounding: case kNVTEQuantizationConfigStochasticRounding:
std::memcpy(&config_.stochastic_rounding, buf, attr_size); std::memcpy(&config_.stochastic_rounding, buf, attr_size);
break; break;
case kNVTEQuantizationConfigUseFastMath:
std::memcpy(&config_.use_fast_math, buf, attr_size);
break;
default: default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")"); NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
} }
......
...@@ -1501,7 +1501,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou ...@@ -1501,7 +1501,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
} }
} }
// Restriction for the RHT cast fusion kernel. // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT
bool eligible_for_rht_cast_fusion = bool eligible_for_rht_cast_fusion =
input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;
......
...@@ -120,7 +120,7 @@ def get_align_size_for_quantization(recipe: Recipe) -> int: ...@@ -120,7 +120,7 @@ def get_align_size_for_quantization(recipe: Recipe) -> int:
if recipe.mxfp8(): if recipe.mxfp8():
return 32 return 32
if recipe.nvfp4(): if recipe.nvfp4():
return 64 return 128
return 16 return 16
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment