/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "../util/math.h"
#include "./activation_template.h"

// C API entry points for the GELU and QGELU activation families.
//
// Every function below is a thin wrapper: it registers the call with
// NVTE_API_CALL and forwards its arguments to a templated helper, selecting the
// element-wise operator (gelu / dgelu / qgelu / dqgelu) at compile time.
//
// NOTE(review): the explicit template argument lists follow the
// <compute type, parameter type, operator> pattern with fp32 as the compute
// type -- confirm against the helper declarations in activation_template.h.
// NOTE(review): the literal nullptr passed just before `stream` to the dispatch
// helpers is forwarded as-is; its meaning is defined by those helpers.

// ----------------------------------------------------------------------------
// GELU
// ----------------------------------------------------------------------------

// Forward pass on a single tensor: forwards to act_fn with the gelu operator.
void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_gelu);
  using namespace transformer_engine;
  act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
}

// Forward pass on a grouped tensor, routed through the grouped forward
// quantize helper with the activation enabled (IS_ACT = true).
void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                     cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_gelu);
  using namespace transformer_engine;
  constexpr bool IS_ACT = true;
  dispatch::group_quantize_fwd_helper<IS_ACT, Empty, gelu<fp32, fp32>>(input, output, nullptr,
                                                                       stream);
}

// Backward pass on a single tensor: forwards the incoming gradient and the
// forward input to dact_fn with the dgelu operator.
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgelu);
  using namespace transformer_engine;
  dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
}

// Backward pass on a grouped tensor. The bias gradient is disabled
// (IS_DBIAS = false), so null dbias and workspace handles are passed.
void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_dgelu);
  using namespace transformer_engine;
  NVTETensor dbias = nullptr;
  NVTETensor workspace = nullptr;

  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = true;

  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
      grad, input, output, dbias, workspace, nullptr, stream);
}

// Backward quantize on a single tensor with both the bias gradient and the
// dgelu operator enabled (IS_DBIAS = IS_DACT = true).
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dgelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;

  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

// Grouped-tensor counterpart of nvte_quantize_dbias_dgelu.
void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
                                     const NVTEGroupedTensor activation_input,
                                     NVTEGroupedTensor output, NVTETensor dbias,
                                     NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_dgelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;

  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

// Gated forward pass: forwards to gated_act_fn with the gelu operator and a
// default-initialized Empty parameter object.
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_geglu);
  using namespace transformer_engine;
  Empty e = {};
  gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, e, stream);
}

// Gated backward pass: dgated_act_fn receives both the gelu and the dgelu
// operators.
void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgeglu);
  using namespace transformer_engine;
  Empty e = {};
  dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(grad, input, output, e, stream);
}

// ----------------------------------------------------------------------------
// QGELU (same wrappers as above, using the qgelu / dqgelu operators)
// ----------------------------------------------------------------------------

// Forward pass on a single tensor: forwards to act_fn with the qgelu operator.
void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_qgelu);
  using namespace transformer_engine;
  act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
}

// Forward pass on a grouped tensor, routed through the grouped forward
// quantize helper with the activation enabled (IS_ACT = true).
void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                      cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_qgelu);
  using namespace transformer_engine;
  constexpr bool IS_ACT = true;
  dispatch::group_quantize_fwd_helper<IS_ACT, Empty, qgelu<fp32, fp32>>(input, output, nullptr,
                                                                        stream);
}

// Backward pass on a single tensor: forwards the incoming gradient and the
// forward input to dact_fn with the dqgelu operator.
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dqgelu);
  using namespace transformer_engine;
  dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
}

// Backward pass on a grouped tensor. The bias gradient is disabled
// (IS_DBIAS = false), so null dbias and workspace handles are passed.
void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                       NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_dqgelu);
  using namespace transformer_engine;
  NVTETensor dbias = nullptr;
  NVTETensor workspace = nullptr;

  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = true;

  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
      grad, input, output, dbias, workspace, nullptr, stream);
}

// Backward quantize on a single tensor with both the bias gradient and the
// dqgelu operator enabled (IS_DBIAS = IS_DACT = true).
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;

  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

// Grouped-tensor counterpart of nvte_quantize_dbias_dqgelu.
void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
                                      const NVTEGroupedTensor activation_input,
                                      NVTEGroupedTensor output, NVTETensor dbias,
                                      NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;

  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

// Gated forward pass: forwards to gated_act_fn with the qgelu operator and a
// default-initialized Empty parameter object.
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_qgeglu);
  using namespace transformer_engine;
  Empty e = {};
  gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, e, stream);
}

// Gated backward pass: dgated_act_fn receives both the qgelu and the dqgelu
// operators.
void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_dqgeglu);
  using namespace transformer_engine;
  Empty e = {};
  dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(grad, input, output, e,
                                                                    stream);
}