relu.cu 3.25 KB
Newer Older
1
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "../util/math.h"
8
#include "./activation_template.h"
9

10
// C-API entry point: forwards to act_fn instantiated with relu<fp32, fp32>
// (presumably the ReLU forward pass -- the op itself is defined elsewhere).
//
//   input  : opaque handle, reinterpreted here as a read-only Tensor.
//   output : opaque handle, reinterpreted here as a writable Tensor.
//   stream : CUDA stream handed through to act_fn unchanged.
void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_relu);
  using namespace transformer_engine;

  // Recover the typed tensor objects behind the opaque handles.
  const Tensor &input_tensor = *reinterpret_cast<const Tensor *>(input);
  Tensor *const output_tensor = reinterpret_cast<Tensor *>(output);

  act_fn<fp32, Empty, relu<fp32, fp32>>(input_tensor, output_tensor, stream);
}

17
// C-API entry point: forwards to dact_fn instantiated with drelu<fp32, fp32>
// (presumably the ReLU backward pass -- the op itself is defined elsewhere).
//
//   grad   : opaque handle, reinterpreted here as a read-only Tensor.
//   input  : opaque handle, reinterpreted here as a read-only Tensor.
//   output : opaque handle, reinterpreted here as a writable Tensor.
//   stream : CUDA stream handed through to dact_fn unchanged.
void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_drelu);
  using namespace transformer_engine;

  // Recover the typed tensor objects behind the opaque handles.
  const Tensor &grad_tensor = *reinterpret_cast<const Tensor *>(grad);
  const Tensor &input_tensor = *reinterpret_cast<const Tensor *>(input);
  Tensor *const output_tensor = reinterpret_cast<Tensor *>(output);

  dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad_tensor, input_tensor, output_tensor, stream);
}

26
// C-API entry point: forwards to gated_act_fn instantiated with relu<fp32, fp32>
// (presumably the gated-ReLU forward pass -- gating logic lives in gated_act_fn).
//
//   input  : opaque handle, reinterpreted here as a read-only Tensor.
//   output : opaque handle, reinterpreted here as a writable Tensor.
//   stream : CUDA stream handed through to gated_act_fn unchanged.
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_reglu);
  using namespace transformer_engine;

  // Recover the typed tensor objects behind the opaque handles.
  const Tensor &input_tensor = *reinterpret_cast<const Tensor *>(input);
  Tensor *const output_tensor = reinterpret_cast<Tensor *>(output);

  gated_act_fn<fp32, Empty, relu<fp32, fp32>>(input_tensor, output_tensor, stream);
}

33
// C-API entry point: forwards to dgated_act_fn instantiated with both
// relu<fp32, fp32> and drelu<fp32, fp32> (presumably the gated-ReLU backward
// pass -- the actual computation lives in dgated_act_fn).
//
//   grad   : opaque handle, reinterpreted here as a read-only Tensor.
//   input  : opaque handle, reinterpreted here as a read-only Tensor.
//   output : opaque handle, reinterpreted here as a writable Tensor.
//   stream : CUDA stream handed through to dgated_act_fn unchanged.
void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dreglu);
  using namespace transformer_engine;

  // Recover the typed tensor objects behind the opaque handles.
  const Tensor &grad_tensor = *reinterpret_cast<const Tensor *>(grad);
  const Tensor &input_tensor = *reinterpret_cast<const Tensor *>(input);
  Tensor *const output_tensor = reinterpret_cast<Tensor *>(output);

  dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(grad_tensor, input_tensor,
                                                                  output_tensor, stream);
}
41

42
// C-API entry point: forwards to act_fn instantiated with srelu<fp32, fp32>
// (presumably a squared-ReLU forward pass -- confirm against the srelu definition).
//
//   input  : opaque handle, reinterpreted here as a read-only Tensor.
//   output : opaque handle, reinterpreted here as a writable Tensor.
//   stream : CUDA stream handed through to act_fn unchanged.
void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_srelu);
  using namespace transformer_engine;

  // Recover the typed tensor objects behind the opaque handles.
  const Tensor &input_tensor = *reinterpret_cast<const Tensor *>(input);
  Tensor *const output_tensor = reinterpret_cast<Tensor *>(output);

  act_fn<fp32, Empty, srelu<fp32, fp32>>(input_tensor, output_tensor, stream);
}

49
50
// C-API entry point: forwards to dact_fn instantiated with dsrelu<fp32, fp32>
// (presumably the backward pass matching srelu -- confirm against the dsrelu
// definition).
//
//   grad   : opaque handle, reinterpreted here as a read-only Tensor.
//   input  : opaque handle, reinterpreted here as a read-only Tensor.
//   output : opaque handle, reinterpreted here as a writable Tensor.
//   stream : CUDA stream handed through to dact_fn unchanged.
void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dsrelu);
  using namespace transformer_engine;

  // Recover the typed tensor objects behind the opaque handles.
  const Tensor &grad_tensor = *reinterpret_cast<const Tensor *>(grad);
  const Tensor &input_tensor = *reinterpret_cast<const Tensor *>(input);
  Tensor *const output_tensor = reinterpret_cast<Tensor *>(output);

  dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad_tensor, input_tensor, output_tensor, stream);
}

58
// C-API entry point: forwards to gated_act_fn instantiated with srelu<fp32, fp32>
// (presumably the gated variant of srelu -- gating logic lives in gated_act_fn).
//
//   input  : opaque handle, reinterpreted here as a read-only Tensor.
//   output : opaque handle, reinterpreted here as a writable Tensor.
//   stream : CUDA stream handed through to gated_act_fn unchanged.
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_sreglu);
  using namespace transformer_engine;

  // Recover the typed tensor objects behind the opaque handles.
  const Tensor &input_tensor = *reinterpret_cast<const Tensor *>(input);
  Tensor *const output_tensor = reinterpret_cast<Tensor *>(output);

  gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(input_tensor, output_tensor, stream);
}

65
66
// C-API entry point: forwards to dgated_act_fn instantiated with both
// srelu<fp32, fp32> and dsrelu<fp32, fp32> (presumably the backward pass of the
// gated srelu variant -- the actual computation lives in dgated_act_fn).
//
//   grad   : opaque handle, reinterpreted here as a read-only Tensor.
//   input  : opaque handle, reinterpreted here as a read-only Tensor.
//   output : opaque handle, reinterpreted here as a writable Tensor.
//   stream : CUDA stream handed through to dgated_act_fn unchanged.
void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_dsreglu);
  using namespace transformer_engine;

  // Recover the typed tensor objects behind the opaque handles.
  const Tensor &grad_tensor = *reinterpret_cast<const Tensor *>(grad);
  const Tensor &input_tensor = *reinterpret_cast<const Tensor *>(input);
  Tensor *const output_tensor = reinterpret_cast<Tensor *>(output);

  dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(grad_tensor, input_tensor,
                                                                    output_tensor, stream);
}