/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

// NOTE(review): this file arrived with every angle-bracket span stripped
// (bare "#include" directives, empty template argument lists). The
// angle-bracket includes below are reconstructed from what the file
// visibly uses (NVTE* C API types, nvte_get_num_compute_streams,
// std::min) -- TODO confirm against the repository.
#include <transformer_engine/cast.h>
#include <transformer_engine/multi_stream.h>

#include <algorithm>
#include <cstddef>

#include "../common.h"
#include "../transpose/cast_transpose.h"
#include "../util/multi_stream.h"
#include "../utils.cuh"
#include "dispatch/dequantize.cuh"
#include "dispatch/quantize.cuh"
#include "transformer_engine/transpose.h"

/*! \brief Quantize \p input into \p output on \p stream.
 *
 *  Thin C API wrapper over the templated dispatch helper, with the
 *  activation path compiled out and no quantization config.
 */
void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize);
  using namespace transformer_engine;

  constexpr bool IS_ACT = false;
  // NOTE(review): the helper's template argument list was lost in transit;
  // <IS_ACT> is reconstructed from the otherwise-unused constant above --
  // TODO confirm the full parameter list against dispatch/quantize.cuh.
  dispatch::quantize_fwd_helper<IS_ACT>(input, output, nullptr, stream);
}

/*! \brief Quantize \p input into \p output, skipping the work entirely when
 *         the \p noop tensor signals a no-op.
 *
 *  Implemented by packing \p noop into a QuantizationConfig and forwarding
 *  to nvte_quantize_v2.
 */
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
                        cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_noop);
  using namespace transformer_engine;

  // Create config with noop tensor
  QuantizationConfig quant_config;
  quant_config.noop_tensor = noop;

  // NOTE(review): the cast's type argument was stripped in transit;
  // NVTEQuantizationConfig is the only type the callee accepts here.
  nvte_quantize_v2(input, output, reinterpret_cast<NVTEQuantizationConfig>(&quant_config), stream);
}

/*! \brief Quantize \p input into \p output with an explicit quantization
 *         config (may be null).
 */
void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
                      const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_v2);
  using namespace transformer_engine;

  constexpr bool IS_ACT = false;
  // NOTE(review): reconstructed template arguments -- see nvte_quantize.
  dispatch::quantize_fwd_helper<IS_ACT>(input, output, quant_config, stream);
}

/*! \brief Quantize \p input into \p output and additionally reduce it into
 *         \p dbias, using \p workspace for the reduction.
 *
 *  No activation-gradient fusion: the dact path is compiled out and the
 *  activation input is null.
 */
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
                         NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = false;
  constexpr const NVTETensor activation_input = nullptr;

  // NOTE(review): reconstructed template arguments from the otherwise-unused
  // constants above -- TODO confirm against dispatch/quantize.cuh.
  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT>(input, activation_input, output, dbias,
                                                   workspace, nullptr, stream);
}

/*! \brief Dequantize \p input into \p output on \p stream. */
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_dequantize);
  using namespace transformer_engine;
  dispatch::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output),
                              stream);
}

/*! \brief Quantize \p num_tensors tensor pairs, round-robined over the
 *         library's internal compute streams for concurrency.
 *
 *  Ordering: every compute stream first waits on work already queued on
 *  \p stream; at the end \p stream waits on all used compute streams, so
 *  the call behaves as if it ran entirely on \p stream.
 */
void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
                                const NVTEQuantizationConfig quant_configs,
                                const size_t num_tensors, cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_tensor_quantize);
  using namespace transformer_engine;

  constexpr bool IS_ACT = false;

  const size_t num_streams = nvte_get_num_compute_streams();
  // Never engage more side streams than there are tensors to process.
  // (size_t, not int: avoids signed/unsigned comparisons below.)
  const size_t num_stream_used = std::min(num_streams, num_tensors);

  // wait for current stream to finish
  NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream));
  for (size_t s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(
        cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0)));
  }

  // Round-robin the tensors over the compute streams.
  for (size_t i = 0; i < num_tensors; i++) {
    // NOTE(review): reconstructed template arguments -- see nvte_quantize.
    dispatch::quantize_fwd_helper<IS_ACT>(inputs[i], outputs[i], quant_configs,
                                          detail::get_compute_stream(i % num_streams));
  }

  // record events on compute streams
  for (size_t s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(
        cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s)));
  }

  // wait for all compute streams to finish
  for (size_t s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
  }
}

// Group quantize assumes contiguous inputs and outputs in memory allocation
// TODO (zhongbo): find a better way to make it a more generalized API
/*! \brief NVFP4-quantize one contiguous \p input split into \p num_tensors
 *         sections (sizes in \p split_sections), writing one output tensor
 *         per section, with amax handling driven by \p quant_config.
 */
void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs,
                                         const size_t *split_sections, const size_t num_tensors,
                                         const NVTEQuantizationConfig quant_config,
                                         cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_nvfp4_quantize_with_amax);
  using namespace transformer_engine;

  constexpr bool IS_ACT = false;
  // NOTE(review): reconstructed template arguments -- see nvte_quantize.
  dispatch::group_quantize_fwd_helper<IS_ACT>(input, outputs, split_sections, num_tensors,
                                              quant_config, stream);
}