/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

// C API entry points for quantization and dequantization. Each function registers itself
// with NVTE_API_CALL and then forwards its arguments to detail::quantize_helper or
// detail::dequantize_helper, which perform the actual work.
//
// NOTE(review): in this copy of the file the `#include <...>` header names and the explicit
// template argument lists (e.g. on `detail::quantize_helper` and `reinterpret_cast`) appear
// to have been stripped, so it cannot compile as written. Restore them from version control
// before building.

#include
#include
#include
#include
#include
#include
#include

#include "../common.h"
#include "../transpose/cast_transpose.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "cast_kernels.cuh"
#include "dequantize_kernels.cuh"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/transpose.h"

// Quantizes `input` into `output`.
// The dbias (IS_DBIAS), activation-derivative (IS_DACT) and activation (IS_ACT) flags are
// all false. The grad, dbias and workspace handles are all nullptr, and nullptr is passed in
// the position where nvte_quantize_v2 passes its quantization config. `stream` is forwarded
// unchanged.
void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = false;
  constexpr bool IS_ACT = false;
  constexpr NVTETensor dbias = nullptr;
  constexpr NVTETensor workspace = nullptr;
  constexpr const NVTETensor grad = nullptr;

  // Argument order used by every quantize_helper call in this file:
  // (input / activation input, grad, output, dbias, workspace, config, stream).
  detail::quantize_helper(input, grad, output, dbias, workspace, nullptr, stream);
}

// Convenience wrapper around nvte_quantize_v2.
// It builds a local QuantizationConfig, sets its `noop_tensor` field to `noop`, and forwards
// that config. This function does not interpret `noop` itself; that happens downstream.
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
                        cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_noop);
  using namespace transformer_engine;

  // Create config with noop tensor
  QuantizationConfig quant_config;
  quant_config.noop_tensor = noop;
  // NOTE(review): quant_config is a local object. Passing its address is only safe if
  // nvte_quantize_v2 does not retain the config pointer after it returns — confirm.
  nvte_quantize_v2(input, output, reinterpret_cast(&quant_config), stream);
}

// Same as nvte_quantize, except that the caller-supplied `quant_config` is forwarded to
// quantize_helper instead of nullptr.
void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
                      const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_v2);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = false;
  constexpr bool IS_ACT = false;
  constexpr NVTETensor dbias = nullptr;
  constexpr NVTETensor workspace = nullptr;
  constexpr const NVTETensor grad = nullptr;

  detail::quantize_helper(
      input, grad, output, dbias, workspace, quant_config, stream);
}

// Quantizes `input` into `output` with IS_DBIAS = true.
// The caller's `dbias` and `workspace` tensors are passed through to quantize_helper. No
// activation input is used (nullptr in the first argument position). `input` is passed in
// the position that nvte_quantize uses for `grad`. No quantization config is supplied.
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
                         NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = false;
  constexpr bool IS_ACT = false;
  constexpr const NVTETensor activation_input = nullptr;

  detail::quantize_helper(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}

// Like nvte_quantize_dbias, but with IS_DACT = true.
// The caller's `activation_input` is forwarded in the first argument position, and `input`
// in the grad position. Per the function name this is the dGeLU variant; the derivative
// function is presumably selected via quantize_helper's template arguments, which are
// missing in this copy.
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dgelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}

// Like nvte_quantize_dbias_dgelu, but for the dSiLU variant (per the function name).
// The derivative choice is presumably made in quantize_helper's (missing) template
// arguments.
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dsilu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}

// Like nvte_quantize_dbias_dgelu, but for the dReLU variant (per the function name).
// The derivative choice is presumably made in quantize_helper's (missing) template
// arguments.
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_drelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}

// Like nvte_quantize_dbias_dgelu, but for the dQGeLU variant (per the function name).
// The derivative choice is presumably made in quantize_helper's (missing) template
// arguments.
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}

// Like nvte_quantize_dbias_dgelu, but for the dSReLU variant (per the function name).
// The derivative choice is presumably made in quantize_helper's (missing) template
// arguments.
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dsrelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}

// Dequantizes `input` into `output` via detail::dequantize_helper.
// Both NVTETensor handles are reinterpret_cast to another tensor type (the cast target is
// missing in this copy). The input is dereferenced before the call, while the output is
// passed as the cast pointer. `stream` is forwarded unchanged.
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_dequantize);
  using namespace transformer_engine;
  detail::dequantize_helper(*reinterpret_cast(input), reinterpret_cast(output),
                            stream);
}