/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "../util/math.h"
#include "./activation_template.h"

// C API entry points for the ReLU family of activations (ReLU, ReGLU, squared ReLU, squared ReGLU)
// and their backward passes.
//
// Every entry point follows the same pattern:
//   1. reinterpret the opaque NVTETensor handles as transformer_engine::Tensor objects
//      (read-only inputs as `const Tensor *`, the destination as `Tensor *`);
//   2. forward them, together with `stream`, to one of the launchers declared in
//      activation_template.h (act_fn / dact_fn / gated_act_fn / dgated_act_fn);
//   3. select the math via template arguments: compute type `fp32`, op-parameter type `Empty`,
//      and the element op(s) `relu`/`drelu`/`srelu`/`dsrelu` from ../util/math.h.
// NOTE(review): the template argument lists below were missing from the extracted source (only
// fragments such as `, drelu>(` survived); they are restored to match the launcher/op pairing
// implied by those fragments -- confirm against activation_template.h.

// Forward ReLU: applies the `relu` op to `input` and writes the result into `output`.
void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_relu);
  using namespace transformer_engine;
  act_fn<fp32, Empty, relu<fp32, fp32>>(*reinterpret_cast<const Tensor *>(input),
                                        reinterpret_cast<Tensor *>(output), stream);
}

// Backward ReLU: combines the incoming gradient `grad` with the `drelu` op evaluated on the
// forward-pass `input` and writes the result into `output`.
// NOTE(review): presumably output = grad * drelu(input); the exact combination is defined by
// dact_fn in activation_template.h.
void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_drelu);
  using namespace transformer_engine;
  dact_fn<fp32, Empty, drelu<fp32, fp32>>(*reinterpret_cast<const Tensor *>(grad),
                                          *reinterpret_cast<const Tensor *>(input),
                                          reinterpret_cast<Tensor *>(output), stream);
}

// Forward ReGLU: gated activation that uses `relu` as the activation op.
// NOTE(review): how `input` is split into activation/gate parts and the resulting `output`
// shape are defined by gated_act_fn, not here -- see activation_template.h.
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_reglu);
  using namespace transformer_engine;
  gated_act_fn<fp32, Empty, relu<fp32, fp32>>(*reinterpret_cast<const Tensor *>(input),
                                              reinterpret_cast<Tensor *>(output), stream);
}

// Backward ReGLU: uses `relu` (forward op) and `drelu` (its derivative) to turn `grad` and the
// forward-pass `input` into the gradient written to `output`.
void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dreglu);
  using namespace transformer_engine;
  dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(
      *reinterpret_cast<const Tensor *>(grad), *reinterpret_cast<const Tensor *>(input),
      reinterpret_cast<Tensor *>(output), stream);
}

// Forward squared ReLU: applies the `srelu` op to `input` and writes the result into `output`.
// NOTE(review): "squared ReLU" follows the `srelu` naming; the definition lives in
// ../util/math.h.
void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_srelu);
  using namespace transformer_engine;
  act_fn<fp32, Empty, srelu<fp32, fp32>>(*reinterpret_cast<const Tensor *>(input),
                                         reinterpret_cast<Tensor *>(output), stream);
}

// Backward squared ReLU: combines `grad` with the `dsrelu` op evaluated on the forward-pass
// `input` and writes the result into `output`.
void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dsrelu);
  using namespace transformer_engine;
  dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(*reinterpret_cast<const Tensor *>(grad),
                                           *reinterpret_cast<const Tensor *>(input),
                                           reinterpret_cast<Tensor *>(output), stream);
}

// Forward squared ReGLU: gated activation that uses `srelu` as the activation op
// (input/gate layout defined by gated_act_fn).
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_sreglu);
  using namespace transformer_engine;
  gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(*reinterpret_cast<const Tensor *>(input),
                                               reinterpret_cast<Tensor *>(output), stream);
}

// Backward squared ReGLU: uses `srelu` (forward op) and `dsrelu` (its derivative) to turn `grad`
// and the forward-pass `input` into the gradient written to `output`.
void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_dsreglu);
  using namespace transformer_engine;
  dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(
      *reinterpret_cast<const Tensor *>(grad), *reinterpret_cast<const Tensor *>(input),
      reinterpret_cast<Tensor *>(output), stream);
}