Unverified Commit eb8e792b authored by Zhongbo Zhu, committed by GitHub

[PyTorch][NVFP4][MOE] NVFP4 Grouped Quantize with Hadamard Transform (#2411)



* rowwise colwise RHT group quant v1
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* remove local array RW
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* change wait_barrier
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* fast math options
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* use mult to replace div
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* format
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* bulk move random states
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* greptile
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* lint
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* revert to use divides
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* avoid fp32 bf16 round-trip in RHT cast fusion
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* trigger fastmath by toggle NVTE_RHT_CAST_FUSION_USE_FAST_MATH
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* integrate row col rht fusion, functional
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* numerics aligned
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* style
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* remove device sync
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* 128 padding
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* revert colwise rng state creation because of row-col fused kernel
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* fix CI, linter
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* refactor RS for generating two random values
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* Avoid invalid configs with templated kernel
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* fix acc pipeline init with 0 arrival count
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* restore rowwise-only mode
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* switch to dynamic atomic scheduler
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* Avoid instantiating group RHT+cast kernel without row-wise or col-wise output
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Include fast math option in quantization config
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings and review nits
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use TE license
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug where kernel is always launched on stream
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Restore BF16 intermediate downcast in fused RHT-cast kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* fix numerical test of grouped kernel
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* Make sure row-wise and col-wise quantization use different RNG seeds
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Restore autoformatter
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 47902e96
......@@ -53,7 +53,7 @@ ncu -f -o ./benchmarks/linear/ncu_b200_numgemm_8_nvfp4_rht_amax \
--set=full \
--kernel-name "GroupHadamardAmaxTmaKernel" \
-s 5 -c 5 \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4 --profile
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4
"""
......@@ -173,7 +173,9 @@ def benchmark_linear(
return timing_ms
def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None):
def run_benchmark_linear(
mkns, recipe_name, use_bias, num_gemms=4, m_splits_provided=None, fwd_only=False
):
data = []
assert not use_bias, "Bias is not supported for GroupedLinear benchmark"
......@@ -182,14 +184,14 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
device = "cuda"
x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
assert m % num_gemms == 0
m_splits = [m // num_gemms] * num_gemms if m_splits is None else m_splits
m_splits = [m // num_gemms] * num_gemms if m_splits_provided is None else m_splits_provided
# Bias is not supported for GroupedLinear benchmark
bias = None
# Run the benchmark
print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}")
print(f"m_splits: {m_splits}")
print(f"fwd_only: {fwd_only}")
grouped_fwd_bwd_timing_ms = benchmark_linear(
x,
......@@ -197,7 +199,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
m_splits,
bias,
recipe_name,
mode="fwd_bwd",
mode="fwd_only" if fwd_only else "fwd_bwd",
num_gemms=num_gemms,
)
......@@ -213,6 +215,8 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
]
)
timing_notation = "grouped_fwd_time_ms" if fwd_only else "grouped_fwd_bwd_time_ms"
df = pd.DataFrame(
data=data,
columns=[
......@@ -221,7 +225,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits=None
"n",
"recipe",
"num_gemms",
"grouped_fwd_bwd_time_ms",
timing_notation,
],
)
......@@ -234,7 +238,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
parser.add_argument(
"--output_dir",
"--output-dir",
type=str,
default="benchmark_output/",
help="output path for report",
......@@ -266,6 +270,12 @@ if __name__ == "__main__":
default=2048,
help="Output dimension to use, default is 2048",
)
parser.add_argument(
"--fwd-only",
action="store_true",
default=False,
help="Run forward pass only, default is both forward and backward passes",
)
args = parser.parse_args()
jagged_input_splits = None
......@@ -297,7 +307,7 @@ if __name__ == "__main__":
if jagged_input_splits is not None:
num_gemms_list = [len(jagged_input_splits)]
token_dim_list = [65536]
token_dim_list = [16384, 32768, 65536, 98304]
hidden_dim_list = [7168]
output_dim_list = [2048]
......@@ -371,7 +381,8 @@ if __name__ == "__main__":
recipe_name,
use_bias,
num_gemms=num_gemms,
m_splits=jagged_input_splits,
m_splits_provided=jagged_input_splits,
fwd_only=args.fwd_only,
)
df_linears = pd.concat([df_linears, df])
......
......@@ -198,7 +198,7 @@ def check_group_quantization_nvfp4_versus_reference(
for i in range(len(x_qx)):
if split_sections[i] == 0:
# then just assert the same same and dtype because the buffer won't be zero out
# then just assert the same shape and dtype because the buffer won't be zero out
assert_same_shape_and_dtype(x_amax_rowwise[i], x_amax_rowwise_ref[i])
assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i])
assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i])
......@@ -221,7 +221,7 @@ def check_group_quantization_nvfp4_versus_reference(
# assert with zero tolerance
for i in range(len(x_qx_t)):
if split_sections[i] == 0:
# then just assert the same same and dtype because the buffer won't be zero out
# then just assert the same shape and dtype because the buffer won't be zero out
assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i])
assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i])
assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i])
......@@ -247,6 +247,7 @@ def check_group_quantization_nvfp4_versus_reference(
(1024, 256),
# larger sizes
(8192, 1024),
(16384, 8192),
(16384, 16384),
],
)
......
......@@ -174,6 +174,8 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
hadamard_transform/group_hadamard_transform.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
hadamard_transform/group_hadamard_transform_cast_fusion.cu
hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu
transpose/quantize_transpose_square_blockwise.cu
......
......@@ -100,3 +100,18 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
}
}
// Group quantize assumes contiguous inputs and outputs in memory allocation
// TODO (zhongbo): find a better way to make it a more generalized API
void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs,
const size_t *split_sections, const size_t num_tensors,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_nvfp4_quantize_with_amax);
using namespace transformer_engine;
constexpr bool IS_ACT = false;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, outputs, split_sections,
num_tensors, quant_config, stream);
}
......@@ -19,6 +19,7 @@
#include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh"
#include "../nvfp4/group_quantize_transpose_nvfp4.cuh"
#include "../nvfp4/quantize_nvfp4.cuh"
#include "../nvfp4/quantize_transpose_nvfp4.cuh"
......@@ -320,6 +321,70 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
}
}
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs,
const size_t *split_sections, const size_t num_tensors,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
using namespace detail;
const Tensor *input_tensor = convertNVTETensorCheck(input);
std::vector<Tensor *> output_tensors;
for (size_t i = 0; i < num_tensors; ++i) {
output_tensors.push_back(convertNVTETensorCheck(outputs[i]));
}
// Quantization config
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// Check for unsupported options
if (quant_config_cpp.stochastic_rounding) {
NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Stochastic rounding is only supported for NVFP4 quantization.");
}
// Take the scaling mode of the first output tensor
auto scaling_mode = output_tensors[0]->scaling_mode;
// Dispatch to quantization kernel depending on data format
switch (scaling_mode) {
case NVTE_NVFP4_1D_SCALING: {
NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING");
// Check tensors
CheckNoopTensor(*noop_tensor, "cast_noop");
CheckInputTensor(*input_tensor, "input");
// Skip checking output tensor list
// output list here is allowed to have empty tensor
// Choose kernel
int32_t rows = input_tensor->flat_first_dim();
int32_t cols = input_tensor->flat_last_dim();
auto dtype = input_tensor->dtype();
NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization,
"2D quantization is not supported for group quantize.");
// Launch NVFP4 group quantize kernel
nvfp4::group_quantize_transpose</*use_2d_quantization*/ false>(
*input_tensor, noop_tensor, output_tensors, split_sections, num_tensors,
&quant_config_cpp, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
}
}
} // namespace dispatch
} // namespace transformer_engine
......
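The helper above partitions the flattened input along its first dimension, with output i receiving split_sections[i] rows. A minimal sketch of that contract (the function name and offsets below are illustrative, not part of the patch):

#include <cstddef>
#include <vector>

// Row ranges implied by split_sections for a contiguous grouped input of shape
// [sum(split_sections), cols]: output i covers rows [offsets[i], offsets[i + 1]).
std::vector<size_t> split_row_offsets(const size_t *split_sections, size_t num_tensors) {
  std::vector<size_t> offsets(num_tensors + 1, 0);
  for (size_t i = 0; i < num_tensors; ++i) {
    offsets[i + 1] = offsets[i] + split_sections[i];
  }
  return offsets;
}

A split of zero rows is allowed; as the test change earlier in this diff notes, the corresponding output buffer is not zeroed out, so only its shape and dtype are checked in that case.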
......@@ -394,6 +394,7 @@ struct QuantizationConfig {
NVTETensor rng_state = nullptr;
bool nvfp4_2d_quantization = false;
bool stochastic_rounding = false;
bool use_fast_math = false;
static constexpr size_t attr_sizes[] = {
sizeof(bool), // force_pow_2_scales
......@@ -402,7 +403,8 @@ struct QuantizationConfig {
sizeof(Float8BlockScaleTensorFormat), // float8_block_scale_tensor_format
sizeof(NVTETensor), // rng_seed and offset
sizeof(bool), // nvfp4_2d_quantization
sizeof(bool) // stochastic_rounding
sizeof(bool), // stochastic_rounding
sizeof(bool) // use_fast_math
};
};
......
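The attr_sizes table is indexed by NVTEQuantizationConfigAttribute, so the new use_fast_math entry must sit at the same position as kNVTEQuantizationConfigUseFastMath in the public enum further down in this diff. A purely illustrative compile-time guard one could add (not part of this change) to keep the two in sync:

// Illustrative only: one attr_sizes entry per public attribute enumerator.
static_assert(sizeof(QuantizationConfig::attr_sizes) / sizeof(size_t) ==
                  static_cast<size_t>(kNVTEQuantizationConfigNumAttributes),
              "attr_sizes must have one entry per NVTEQuantizationConfigAttribute");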
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_CUSTOMIZED_PIPELINE_CUH_
#define TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_CUSTOMIZED_PIPELINE_CUH_
#include "cutlass/pipeline/sm100_pipeline.hpp"
namespace cutlass {
using namespace cute;
namespace detail {
// Producer-consumer pipeline implementation for a UMMA producer. In this case,
// UMMA barrier arrives are used by producer_commit. Use case: accumulator
// generation as the result of MMA instructions.
template <int Stages_, class ClusterShape = Shape<int, int, _1>,
class AtomThrShape_MNK_ = Shape<_1, _1, _1> >
class CustomizedPipelineTmaUmmaAsync {
public:
static constexpr uint32_t Stages = Stages_;
using AtomThrShape_MNK = AtomThrShape_MNK_;
private:
using Impl = PipelineTmaAsync<Stages>;
public:
using FullBarrier = typename Impl::FullBarrier;
using EmptyBarrier = typename Impl::EmptyBarrier;
using ProducerBarrierType = typename Impl::ProducerBarrierType;
using ConsumerBarrierType = typename Impl::ConsumerBarrierType;
using PipelineState = typename Impl::PipelineState;
using SharedStorage = typename Impl::SharedStorage;
using ThreadCategory = typename Impl::ThreadCategory;
using Params = typename Impl::Params;
using McastDirection = McastDirection;
// Helper function to initialize barriers
static CUTLASS_DEVICE void init_barriers(SharedStorage& storage, Params params,
ClusterShape cluster_shape) {
int warp_idx = canonical_warp_idx_sync();
if (warp_idx == params.initializing_warp) {
// Barrier FULL and EMPTY init
constexpr int producer_arv_cnt = 1;
auto atom_thr_shape = AtomThrShape_MNK{};
uint32_t multicast_consumer_arrival_count = params.num_consumers; // If cluster_size is 1
if (cute::size(cluster_shape) > 1) {
multicast_consumer_arrival_count =
((cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) +
(cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1) *
params.num_consumers;
}
CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 &&
"Multicast consumer arrival count must be non-zero");
CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero");
cutlass::arch::detail::initialize_barrier_array_pair_aligned<
decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt,
multicast_consumer_arrival_count);
}
cutlass::arch::fence_barrier_init();
}
CUTLASS_DEVICE
void init_masks(ClusterShape cluster_shape,
dim3 block_id_in_cluster = cute::block_id_in_cluster()) {
// Calculate consumer mask
if (params_.role == ThreadCategory::Consumer) {
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRowCol>(
cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
}
}
CUTLASS_DEVICE
void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) {
// Calculate consumer mask
dim3 block_id_in_cluster = cute::block_id_in_cluster();
if (mcast_direction == McastDirection::kRow) {
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRow>(
cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
} else {
block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kCol>(
cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
}
}
// Constructor by default initializes barriers and calculates masks.
// These operations can be explicitly deferred by specifying InitBarriers and InitMasks.
// If deferred, user code needs to guarantee init_masks and/or init_barriers is/are called.
template <typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>
CUTLASS_DEVICE CustomizedPipelineTmaUmmaAsync(SharedStorage& storage, Params params,
ClusterShape cluster_shape, InitBarriers = {},
InitMasks = {})
: impl_(storage, params, cluster_shape, cute::false_type{}, InitMasks{}),
params_(params),
empty_barrier_ptr_(&storage.empty_barrier_[0]),
full_barrier_ptr_(&storage.full_barrier_[0]) {
static_assert(cute::is_same_v<InitBarriers, cute::true_type> ||
cute::is_same_v<InitBarriers, cute::false_type>);
if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {
init_barriers(storage, params_, cluster_shape);
}
static_assert(cute::is_same_v<InitMasks, cute::true_type> ||
cute::is_same_v<InitMasks, cute::false_type>);
if constexpr (cute::is_same_v<InitMasks, cute::true_type>) {
init_masks(cluster_shape);
}
}
////////////////////
// Producer APIs
////////////////////
// Four member functions are always used in pairs:
//
// * producer_try_acquire and producer_acquire, and
// * consumer_try_wait and consumer_wait.
//
// The two functions with "try" in their names are called "try" functions,
// and the other two are conceptually "finalize" functions.
// The "try" function in each pair starts the process of waiting on the barrier to flip.
// It opportunistically waits for an implementation-dependent timeout.
// Whether or not the barrier has flipped yet, the try function will return a token.
// If the token indicates that the barrier has not flipped,
// then the token must be passed into the corresponding "finalize" function.
// The finalize function will then block until the barrier has flipped.
// If the token indicates that the barrier _has_ flipped,
// then it is still correct to pass it into the finalize function.
// The finalize function will return immediately in that case.
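// A minimal, hypothetical sketch of the consumer-side pairing described above
// (the names pipe, state, and token are illustrative only):
//
//   ConsumerToken token = pipe.consumer_try_wait(state);  // start waiting; may time out
//   /* ... overlap independent work here ... */
//   pipe.consumer_wait(state, token);                      // block until the stage is FULL
//   /* ... read stage state.index() ... */
//   pipe.consumer_release(state);                          // signal EMPTY back to the producer
//   ++state;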
CUTLASS_DEVICE
ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) {
return impl_.producer_try_acquire(state, skip_wait);
}
CUTLASS_DEVICE
void producer_acquire(PipelineState state,
ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
impl_.producer_acquire(state, barrier_token);
}
CUTLASS_DEVICE
void producer_expect_transaction(PipelineState state, uint32_t transaction_bytes) {
impl_.producer_expect_transaction(state, transaction_bytes);
}
// NOP for TMA based mainloop
CUTLASS_DEVICE
void producer_commit(PipelineState state, uint32_t bytes) { impl_.producer_commit(state, bytes); }
// Prevents early exit of producer blocks in Cluster.
// This should be called once before kernel exits.
CUTLASS_DEVICE
void producer_tail(PipelineState state) { impl_.producer_tail(state); }
CUTLASS_DEVICE
ProducerBarrierType* producer_get_barrier(PipelineState state) {
return impl_.producer_get_barrier(state);
}
////////////////////
// Consumer APIs
////////////////////
CUTLASS_DEVICE
ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) {
return impl_.consumer_try_wait(state, skip_wait);
}
CUTLASS_DEVICE
void consumer_wait(PipelineState state,
ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) {
impl_.consumer_wait(state, barrier_token);
}
CUTLASS_DEVICE
void umma_consumer_release(PipelineState state) { umma_consumer_release(state.index(), false); }
CUTLASS_DEVICE
void consumer_release(PipelineState state) { impl_.consumer_release(state); }
private:
Impl impl_;
Params params_;
EmptyBarrier* empty_barrier_ptr_;
FullBarrier* full_barrier_ptr_;
uint16_t block_id_mask_ = 0;
static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1;
// Consumer signalling Producer of completion
// Ensures all blocks in the Same Row and Column get notified.
CUTLASS_DEVICE
void umma_consumer_release(uint32_t stage, uint32_t skip) {
detail::pipeline_check_is_consumer(params_.role);
uint64_t* smem_ptr = reinterpret_cast<uint64_t*>(&empty_barrier_ptr_[stage]);
// {$nv-release-never begin}
// TODO: Needs to be updated once Blackwell specialized pipeline is implemented.
// XMMA style bar_peek will be tested. We will need to revisit skip interface and
// what skip means when we have bar_peek functionality.
// A separate MR will implement MMA_2x1SM specialized pipeline.
// {$nv-release-never end}
if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1
if (!skip) {
cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_);
}
} else {
if (!skip) {
if constexpr (cute::is_static_v<ClusterShape> && size(ClusterShape{}) == 1) {
cutlass::arch::umma_arrive(smem_ptr);
} else {
cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_);
}
}
}
}
};
} // namespace detail
} // namespace cutlass
#endif // TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_CUSTOMIZED_PIPELINE_CUH_
......@@ -459,7 +459,7 @@ void group_hadamard_transform_amax(const Tensor& input_, std::vector<Tensor*>& o
}
// Multi zero out multiple amaxes if needed
// Curretly don't support multi-launch when num_tensors is larger than kMaxTensorsPerKernel
// Currently don't support multi-launch when num_tensors is larger than kMaxTensorsPerKernel
// let the number of threads equal to number of tensors, use 1 block, kMaxTensorsPerKernel threads per block
dim3 block_setup_amax(kMaxTensorsPerKernel);
dim3 grid_setup_amax(1);
......
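The launch configuration above (one block of kMaxTensorsPerKernel threads) implies one thread per potential tensor. A hedged sketch of what such an amax zero-out kernel could look like; the kernel name and pointer-array layout are assumptions, not taken from the patch:

// Hypothetical amax zero-out: thread i clears the amax of tensor i, if present.
__global__ void zero_group_amaxes(float *const *amax_ptrs, size_t num_tensors) {
  const size_t i = threadIdx.x;
  if (i < num_tensors && amax_ptrs[i] != nullptr) {
    *amax_ptrs[i] = 0.0f;
  }
}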
......@@ -29,7 +29,6 @@
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/helper_cuda.hpp"
#include "cutlass/util/print_error.hpp"
// clang-format off
......@@ -129,7 +128,8 @@ template <class MShape, class NShape, class KShape, class ClusterTileShape,
class TC, class CStride, class CSmemLayout,
class TSFC,
class TiledMMA,
bool kEnableStochasticRounding = false>
bool kEnableStochasticRounding = false,
bool kUseFastMath = false>
__global__ static
void
rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
......@@ -426,7 +426,13 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val);
const float global_decode_scale = 1.0f / global_encode_scale;
auto sfd_converter = cutlass::NumericConverter<TSFC, float>{};
// Scaling factor for fast math path
float global_encode_scale_multiplier = 1.0f;
if constexpr (kUseFastMath) {
static constexpr float fp4_max_inv = 1.0f / fp4_max;
global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
}
do {
for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) {
......@@ -469,10 +475,13 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
++accumulator_pipe_consumer_state;
// Cast data from FP32 to BF16 to FP32.
auto convert_accum_to_bf16 = cutlass::NumericArrayConverter<cutlass::bfloat16_t, ElementAccumulator, FragmentSize>{};
auto convert_bf16_to_accum = cutlass::NumericArrayConverter<ElementAccumulator, cutlass::bfloat16_t, FragmentSize>{};
tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{})));
if constexpr (!kUseFastMath) {
// Downcast to BF16 for bit-wise compatibility with unfused
// kernels
auto convert_accum_to_bf16 = cutlass::NumericArrayConverter<cutlass::bfloat16_t, ElementAccumulator, FragmentSize>{};
auto convert_bf16_to_accum = cutlass::NumericArrayConverter<ElementAccumulator, cutlass::bfloat16_t, FragmentSize>{};
tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{})));
}
auto compute_frgs = reinterpret_cast<cutlass::Array< ElementAccumulator, VectorSize> *>(tTR_rAcc_frag.data());
auto output_frgs = reinterpret_cast<cutlass::Array< TC, VectorSize> *>(tDrC_frag.data());
......@@ -481,14 +490,27 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]);
}
pvscales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(vec_maxs, fp4_max);
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(pvscales, global_encode_scale);
if constexpr (kUseFastMath) {
// Fast math: multiply with precomputed reciprocal
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(vec_maxs, global_encode_scale_multiplier);
} else {
// Accurate math: perform division
pvscales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(vec_maxs, fp4_max);
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(pvscales, global_encode_scale);
}
auto pvscales_cvted = cutlass::NumericArrayConverter<TSFC, ElementAccumulator, NumVecs>{}(pvscales);
tC_rRowSFD_frg(_0{}) = pvscales_cvted;
auto qpvscale_ups = cutlass::NumericArrayConverter<ElementAccumulator, TSFC, NumVecs>{}(tC_rRowSFD_frg(_0{}));
auto qpvscale_scaled = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(qpvscale_ups, global_decode_scale);
auto acc_scales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(1.0, qpvscale_scaled);
cutlass::Array<ElementAccumulator, NumVecs> acc_scales;
if constexpr (kUseFastMath) {
// Fast math: compute approximate reciprocal
acc_scales = cutlass::reciprocal_approximate_ftz<decltype(qpvscale_scaled)>{}(qpvscale_scaled);
} else {
// Accurate math: compute reciprocal with division
acc_scales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(1.0, qpvscale_scaled);
}
// Initialize RNG for tile
const size_t rng_sequence
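The fast-math branch folds the divide-by-fp4_max into the precomputed global_encode_scale_multiplier and later replaces the exact reciprocal with reciprocal_approximate_ftz. A scalar sketch of the algebra (illustrative only; the kernel applies it to cutlass::Array fragments):

// Accurate path: per-vector scale = (amax / fp4_max) * global_encode_scale.
float accurate_scale(float vec_amax, float fp4_max, float global_encode_scale) {
  return (vec_amax / fp4_max) * global_encode_scale;
}

// Fast-math path: the same value as amax * (global_encode_scale / fp4_max),
// with the multiplier hoisted out of the inner loop.
float fast_scale(float vec_amax, float global_encode_scale_multiplier) {
  return vec_amax * global_encode_scale_multiplier;
}

Because the two paths can differ in the last bits, the BF16 intermediate downcast (kept for bit-wise parity with the unfused kernels) is applied only when kUseFastMath is disabled.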
......@@ -532,7 +554,7 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
// B: 16 x 16: row-major
// C: m x n: row-major
// SFC: m x (n/16): row-major
template <typename TA, typename TB, typename TC, typename TSFC, bool kEnableStochasticRounding = false>
template <typename TA, typename TB, typename TC, typename TSFC, bool kEnableStochasticRounding = false, bool kUseFastMath = false>
void
rht_gemm_ntt_w_sfc(int m, int n,
TA const* A,
......@@ -644,16 +666,15 @@ rht_gemm_ntt_w_sfc(int m, int n,
TC, decltype(dC), decltype(sC),
TSFC,
decltype(mma),
kEnableStochasticRounding>;
kEnableStochasticRounding,
kUseFastMath>;
bool status = cudaFuncSetAttribute(*kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(*kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size)
);
if (status != cudaSuccess) {
std::cerr << "Error: Failed to set Shared Memory size." << std::endl;
return;
}
(*kernel_ptr)
<<< dimGrid, dimBlock, smem_size, stream >>>
(M, N, k_tile_size, cga_tile_shape,
......@@ -663,11 +684,12 @@ rht_gemm_ntt_w_sfc(int m, int n,
SFC,
mma, global_amax,
rng_state);
NVTE_CHECK_CUDA(cudaGetLastError());
}
// this function is used to wrap the rht_gemm_ntt_w_sfc function
//to transpose the input tensor A
template <typename TA, typename TB, typename TC, typename TSFC, bool kEnableStochasticRounding = false>
template <typename TA, typename TB, typename TC, typename TSFC, bool kEnableStochasticRounding = false, bool kUseFastMath = false>
void
rht_gemm_ttt_wrapper(int m, int n,
TA const* A,
......@@ -690,7 +712,7 @@ rht_gemm_ttt_wrapper(int m, int n,
// B: 16 x 16: row-major
// C: n x m: row-major
// SFC: n x (m/16): row-major
rht_gemm_ntt_w_sfc<TA, TB, TC, TSFC, kEnableStochasticRounding>(
rht_gemm_ntt_w_sfc<TA, TB, TC, TSFC, kEnableStochasticRounding, kUseFastMath>(
n, m,
A, B, C,
SFC, global_amax,
......@@ -800,20 +822,23 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
} else if (m < 1024 || n < 1024) {
k_tile_size = 512;
}
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, kUseStochasticRounding,
detail::rht_gemm_ttt_wrapper<TA, TB, TC, TSFC, kUseStochasticRounding>(
/*m=*/m,
/*n=*/n,
/*A=*/reinterpret_cast<TA const *>(input.dptr),
/*B=*/reinterpret_cast<TB const *>(hadamard_matrix.dptr),
/*C=*/reinterpret_cast<TC *>(output_t.dptr),
/*SFC=*/reinterpret_cast<TSFC *>(scale_inv_t.dptr),
/*global_amax=*/reinterpret_cast<float const *>(global_amax.dptr),
/*rng_state=*/rng_state,
/*sm_count=*/sm_count,
/*stream=*/stream,
/*k_tile_size=*/k_tile_size););
TRANSFORMER_ENGINE_SWITCH_CONDITION(
quant_config.use_fast_math, kUseFastMath,
detail::rht_gemm_ttt_wrapper<TA, TB, TC, TSFC, kUseStochasticRounding, kUseFastMath>(
/*m=*/m,
/*n=*/n,
/*A=*/reinterpret_cast<TA const *>(input.dptr),
/*B=*/reinterpret_cast<TB const *>(hadamard_matrix.dptr),
/*C=*/reinterpret_cast<TC *>(output_t.dptr),
/*SFC=*/reinterpret_cast<TSFC *>(scale_inv_t.dptr),
/*global_amax=*/reinterpret_cast<float const *>(global_amax.dptr),
/*rng_state=*/rng_state,
/*sm_count=*/sm_count,
/*stream=*/stream,
/*k_tile_size=*/k_tile_size);););
}
} // namespace transformer_engine
......
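The nested TRANSFORMER_ENGINE_SWITCH_CONDITION macros convert the two runtime flags into constexpr booleans usable as the kEnableStochasticRounding and kUseFastMath template parameters, so four instantiations of rht_gemm_ttt_wrapper are emitted and the statement ends in ");););". A toy, self-contained sketch of the same runtime-to-compile-time dispatch pattern (not the macro's literal expansion):

#include <cstdio>

// A runtime bool is turned into a compile-time template parameter.
template <bool kUseFastMath>
void launch_toy_kernel() {
  std::printf("kUseFastMath = %d\n", kUseFastMath);
}

void dispatch_toy(bool use_fast_math) {
  if (use_fast_math) {
    launch_toy_kernel<true>();   // one instantiation per flag value
  } else {
    launch_toy_kernel<false>();
  }
}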
......@@ -270,6 +270,20 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
const NVTEQuantizationConfig quant_config, const size_t num_tensors,
cudaStream_t stream);
/*! \brief Casts grouped input tensor to quantized output tensors.
*
* \param[in] input Input tensor to be cast.
* \param[in,out] outputs Output quantized tensors.
* \param[in] split_sections Split sections of the input tensor.
* \param[in] num_tensors Number of output tensors.
* \param[in] quant_config (Optional) Quantization configurations.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs,
const size_t *split_sections, size_t num_tensors,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
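A hedged sketch of how the new entry point could be called for a grouped (MoE-style) activation. Tensor and config construction are not shown; the wrapper function name, the header path, and all variable names are illustrative assumptions, not part of the patch:

#include <cuda_runtime.h>
#include <transformer_engine/cast.h>  // assumed location of the declaration above

// Quantize a contiguous [sum(splits), hidden] activation into num_experts NVFP4 outputs.
void quantize_grouped_activation(NVTETensor contiguous_input, NVTETensor *per_expert_outputs,
                                 const size_t *splits, size_t num_experts,
                                 NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  // Outputs must be preallocated; splits may contain zeros for experts with no tokens.
  nvte_group_nvfp4_quantize_with_amax(contiguous_input, per_expert_outputs, splits, num_experts,
                                      quant_config, stream);
}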
......@@ -86,6 +86,43 @@ void nvte_group_hadamard_transform_amax(const NVTETensor input, NVTETensor* outp
int random_sign_mask, int random_sign_mask_t,
cudaStream_t stream);
/*!
* \brief Perform the grouped-tensor columnwise Hadamard transform cast fusion operation.
*
* This function is experimental and the API is not stable. The Group_ prefix indicates that the input is a single contiguous buffer formed by concatenating the grouped tensors along dimension 0.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] outputs Array of output tensors.
* \param[in] hadamard_matrix Hadamard matrix to use for transformation.
* \param[in] split_sections Array specifying splits in dimension 0 for each output tensor.
* \param[in] num_tensors Number of output tensors, must be > 0.
* \param[in] quant_config Quantization configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_hadamard_transform_cast_fusion_columnwise(
const NVTETensor input, NVTETensor* outputs, const NVTETensor hadamard_matrix,
const size_t* split_sections, size_t num_tensors, const NVTEQuantizationConfig quant_config,
cudaStream_t stream);
/*!
* \brief Perform the grouped-tensor row quantize (without Hadamard) and columnwise Hadamard transform cast fusion operation.
*
* This function is experimental and the API is not stable. The Group_ prefix indicates that the input is a single contiguous buffer formed by concatenating the grouped tensors along dimension 0.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] outputs Array of output tensors.
* \param[in] hadamard_matrix Hadamard matrix to use for transformation.
* \param[in] split_sections Array specifying splits in dimension 0 for each output tensor.
* \param[in] num_tensors Number of output tensors, must be > 0.
* \param[in] quant_config Quantization configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor* outputs,
const NVTETensor hadamard_matrix,
const size_t* split_sections, size_t num_tensors,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -337,6 +337,12 @@ enum NVTEQuantizationConfigAttribute {
kNVTEQuantizationConfigNVFP42DQuantization = 5,
/*! Whether to enable stochastic rounding */
kNVTEQuantizationConfigStochasticRounding = 6,
/*! Whether to enable fast math operations with reduced accuracy.
*
* Optimizations are kernel-specific and they may be applied
* inconsistently between kernels.
*/
kNVTEQuantizationConfigUseFastMath = 7,
kNVTEQuantizationConfigNumAttributes
};
......@@ -997,6 +1003,12 @@ class QuantizationConfigWrapper {
&stochastic_rounding, sizeof(bool));
}
/*! \brief Set whether to enable fast math operations */
void set_use_fast_math(bool use_fast_math) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigUseFastMath,
&use_fast_math, sizeof(bool));
}
private:
/*! \brief Wrapped NVTEQuantizationConfig. */
NVTEQuantizationConfig config_ = nullptr;
......
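For callers on the raw C API rather than the QuantizationConfigWrapper, the new flag goes through the existing attribute mechanism. A minimal sketch; the create/destroy calls are assumed to be the pre-existing config-handle API and are not part of this change:

// Illustrative only: enable fast math on a quantization config handle.
NVTEQuantizationConfig cfg = nvte_create_quantization_config();  // assumed existing API
bool use_fast_math = true;
nvte_set_quantization_config_attribute(cfg, kNVTEQuantizationConfigUseFastMath, &use_fast_math,
                                       sizeof(bool));
// ... pass cfg to a quantize call, e.g. nvte_group_nvfp4_quantize_with_amax ...
nvte_destroy_quantization_config(cfg);  // assumed existing API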
......@@ -857,9 +857,10 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
// Write attribute size
NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
"Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
*size_written = attr_size;
if (size_written != nullptr) {
*size_written = attr_size;
}
// Return immediately if buffer is not provided
if (buf == nullptr) {
......@@ -889,6 +890,18 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size);
break;
case kNVTEQuantizationConfigRNGState:
std::memcpy(buf, &config_.rng_state, attr_size);
break;
case kNVTEQuantizationConfigNVFP42DQuantization:
std::memcpy(buf, &config_.nvfp4_2d_quantization, attr_size);
break;
case kNVTEQuantizationConfigStochasticRounding:
std::memcpy(buf, &config_.stochastic_rounding, attr_size);
break;
case kNVTEQuantizationConfigUseFastMath:
std::memcpy(buf, &config_.use_fast_math, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
......@@ -933,6 +946,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigStochasticRounding:
std::memcpy(&config_.stochastic_rounding, buf, attr_size);
break;
case kNVTEQuantizationConfigUseFastMath:
std::memcpy(&config_.use_fast_math, buf, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
......
......@@ -1501,7 +1501,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
}
}
// Restriction for the RHT cast fusion kernel.
// Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT
bool eligible_for_rht_cast_fusion =
input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;
......
......@@ -120,7 +120,7 @@ def get_align_size_for_quantization(recipe: Recipe) -> int:
if recipe.mxfp8():
return 32
if recipe.nvfp4():
return 64
return 128
return 16
......