/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "./activation_template.h"
#include "../util/math.h"

// C API entry points for the GELU family of activations.
//
// Every function below follows the same shape:
//   1. NVTE_API_CALL(<name>) marks entry into the API function.
//   2. The opaque NVTETensor handles are cast back to transformer_engine::Tensor.
//      Input/grad handles are dereferenced, so they must be non-null; the
//      output handle is forwarded as a pointer.
//   3. Work is forwarded to one of the launchers from activation_template.h
//      (act_fn / dact_fn / gated_act_fn / dgated_act_fn) together with `stream`.
//
// NOTE(review): the explicit template arguments (<fp32, Empty, OP<fp32, fp32>>)
// and the reinterpret_cast target types were restored by hand; confirm them
// against the template declarations in activation_template.h and the functor
// definitions in ../util/math.h.
// NOTE(review): the launchers presumably enqueue work asynchronously on
// `stream` -- confirm in activation_template.h before relying on ordering.

/*! \brief Forward pass: applies the `gelu` op via act_fn.
 *
 *  \param[in]     input   Handle of the input tensor (dereferenced; must be non-null).
 *  \param[in,out] output  Handle of the output tensor.
 *  \param[in]     stream  CUDA stream forwarded to act_fn.
 */
void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_gelu);
  using namespace transformer_engine;
  act_fn<fp32, Empty, gelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                        reinterpret_cast<Tensor*>(output), stream);
}

/*! \brief Backward pass: applies the `dgelu` op via dact_fn.
 *
 *  \param[in]     grad    Handle of the incoming gradient tensor (dereferenced).
 *  \param[in]     input   Handle of the forward-pass input tensor (dereferenced).
 *  \param[in,out] output  Handle of the tensor receiving the result.
 *  \param[in]     stream  CUDA stream forwarded to dact_fn.
 */
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgelu);
  using namespace transformer_engine;
  dact_fn<fp32, Empty, dgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
                                          *reinterpret_cast<const Tensor*>(input),
                                          reinterpret_cast<Tensor*>(output), stream);
}

/*! \brief Gated forward pass: applies the `gelu` op via gated_act_fn.
 *
 *  \param[in]     input   Handle of the input tensor (dereferenced; must be non-null).
 *  \param[in,out] output  Handle of the output tensor.
 *  \param[in]     stream  CUDA stream forwarded to gated_act_fn.
 */
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_geglu);
  using namespace transformer_engine;
  gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                              reinterpret_cast<Tensor*>(output), stream);
}

/*! \brief Gated backward pass: passes both `gelu` and `dgelu` to dgated_act_fn.
 *
 *  \param[in]     grad    Handle of the incoming gradient tensor (dereferenced).
 *  \param[in]     input   Handle of the forward-pass input tensor (dereferenced).
 *  \param[in,out] output  Handle of the tensor receiving the result.
 *  \param[in]     stream  CUDA stream forwarded to dgated_act_fn.
 */
void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgeglu);
  using namespace transformer_engine;
  dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(
      *reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
      reinterpret_cast<Tensor*>(output), stream);
}

/*! \brief Forward pass: applies the `qgelu` op via act_fn.
 *
 *  NOTE(review): `qgelu` is inferred from this function's name and the
 *  visible `dqgelu` in the backward paths; confirm it exists in ../util/math.h.
 *
 *  \param[in]     input   Handle of the input tensor (dereferenced; must be non-null).
 *  \param[in,out] output  Handle of the output tensor.
 *  \param[in]     stream  CUDA stream forwarded to act_fn.
 */
void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_qgelu);
  using namespace transformer_engine;
  act_fn<fp32, Empty, qgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                         reinterpret_cast<Tensor*>(output), stream);
}

/*! \brief Backward pass: applies the `dqgelu` op via dact_fn.
 *
 *  \param[in]     grad    Handle of the incoming gradient tensor (dereferenced).
 *  \param[in]     input   Handle of the forward-pass input tensor (dereferenced).
 *  \param[in,out] output  Handle of the tensor receiving the result.
 *  \param[in]     stream  CUDA stream forwarded to dact_fn.
 */
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dqgelu);
  using namespace transformer_engine;
  dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
                                           *reinterpret_cast<const Tensor*>(input),
                                           reinterpret_cast<Tensor*>(output), stream);
}

/*! \brief Gated forward pass: applies the `qgelu` op via gated_act_fn.
 *
 *  NOTE(review): `qgelu` is inferred from this function's name and the
 *  visible `dqgelu` in the backward paths; confirm it exists in ../util/math.h.
 *
 *  \param[in]     input   Handle of the input tensor (dereferenced; must be non-null).
 *  \param[in,out] output  Handle of the output tensor.
 *  \param[in]     stream  CUDA stream forwarded to gated_act_fn.
 */
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_qgeglu);
  using namespace transformer_engine;
  gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                               reinterpret_cast<Tensor*>(output), stream);
}

/*! \brief Gated backward pass: passes both `qgelu` and `dqgelu` to dgated_act_fn.
 *
 *  \param[in]     grad    Handle of the incoming gradient tensor (dereferenced).
 *  \param[in]     input   Handle of the forward-pass input tensor (dereferenced).
 *  \param[in,out] output  Handle of the tensor receiving the result.
 *  \param[in]     stream  CUDA stream forwarded to dgated_act_fn.
 */
void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_dqgeglu);
  using namespace transformer_engine;
  dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(
      *reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
      reinterpret_cast<Tensor*>(output), stream);
}