relu.cu 1.96 KB
Newer Older
1
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

7
#include "./activation_template.h"
8
9
10
11
12
13
14
15
#include "../util/math.h"


// Element-wise ReLU forward pass (C API entry point).
//
// Reinterprets the opaque NVTETensor handles as transformer_engine::Tensor and
// dispatches to act_fn, using relu<fp32, fp32> as the activation functor and
// Empty as the parameter-type template argument (presumably "no extra
// parameters" — NOTE(review): confirm in activation_template.h). fp32 appears
// to be the compute type.
//
//   input  - tensor to activate; passed by const reference to act_fn.
//   output - destination tensor; presumably must match input's shape — TODO confirm.
//   stream - CUDA stream, forwarded unchanged to act_fn.
void nvte_relu(const NVTETensor input,
               NVTETensor output,
               cudaStream_t stream) {
  NVTE_API_CALL(nvte_relu);
  using namespace transformer_engine;
  act_fn<fp32, Empty, relu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                        reinterpret_cast<Tensor*>(output),
                                        stream);
}

// ReLU backward pass (C API entry point).
//
// Reinterprets the opaque NVTETensor handles as transformer_engine::Tensor and
// dispatches to dact_fn, using drelu<fp32, fp32> as the derivative functor and
// Empty as the parameter-type template argument (presumably "no extra
// parameters" — NOTE(review): confirm in activation_template.h).
//
//   grad   - incoming gradient tensor (read-only).
//   input  - forward-pass input to ReLU (read-only).
//   output - destination for the computed gradient; presumably same shape as
//            input — TODO confirm.
//   stream - CUDA stream, forwarded unchanged to dact_fn.
void nvte_drelu(const NVTETensor grad,
                const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_drelu);
  using namespace transformer_engine;
  dact_fn<fp32, Empty, drelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
                                         *reinterpret_cast<const Tensor*>(input),
                                         reinterpret_cast<Tensor*>(output),
                                         stream);
}

// ReGLU forward pass (C API entry point).
//
// Reinterprets the opaque NVTETensor handles as transformer_engine::Tensor and
// dispatches to gated_act_fn with relu<fp32, fp32> as the activation functor.
// NOTE(review): gating layout (e.g. input split in two halves, relu applied to
// one half and multiplied by the other) is defined in gated_act_fn, not here —
// confirm there.
//
//   input  - gated input tensor (read-only).
//   output - destination tensor; its expected shape relative to input is
//            determined by gated_act_fn — TODO confirm.
//   stream - CUDA stream, forwarded unchanged to gated_act_fn.
void nvte_reglu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_reglu);
  using namespace transformer_engine;
  gated_act_fn<fp32, Empty, relu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                              reinterpret_cast<Tensor*>(output),
                                              stream);
}

// ReGLU backward pass (C API entry point).
//
// Reinterprets the opaque NVTETensor handles as transformer_engine::Tensor and
// dispatches to dgated_act_fn. Both the activation functor relu<fp32, fp32> and
// its derivative drelu<fp32, fp32> are supplied. Presumably the gated backward
// needs the forward activation value for one branch and the derivative for
// the other — NOTE(review): confirm in dgated_act_fn.
//
//   grad   - incoming gradient tensor (read-only).
//   input  - forward-pass gated input (read-only).
//   output - destination for the computed input gradient; expected shape is
//            determined by dgated_act_fn — TODO confirm.
//   stream - CUDA stream, forwarded unchanged to dgated_act_fn.
void nvte_dreglu(const NVTETensor grad,
                 const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dreglu);
  using namespace transformer_engine;
  dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(
    *reinterpret_cast<const Tensor*>(grad),
    *reinterpret_cast<const Tensor*>(input),
    reinterpret_cast<Tensor*>(output),
    stream);
}