/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "./activation_template.h"
#include "../util/math.h"

/*! \brief GELU activation forward pass (C API entry point).
 *
 *  Casts the opaque NVTETensor handles to transformer_engine::Tensor and
 *  forwards them to act_fn, instantiated with fp32 and the gelu<fp32, fp32>
 *  functor. fp32 is presumably the compute type and Empty presumably means
 *  "no extra activation parameters" — confirm in activation_template.h.
 *
 *  \param[in]  input   Handle of the tensor to activate.
 *  \param[out] output  Handle of the tensor that receives the result.
 *  \param[in]  stream  CUDA stream passed through to act_fn.
 */
void nvte_gelu(const NVTETensor input,
               NVTETensor output,
               cudaStream_t stream) {
  NVTE_API_CALL(nvte_gelu);
  using namespace transformer_engine;
  act_fn<fp32, Empty, gelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                        reinterpret_cast<Tensor*>(output),
                                        stream);
}

/*! \brief GELU activation backward pass (C API entry point).
 *
 *  Casts the opaque NVTETensor handles to transformer_engine::Tensor and
 *  forwards them to dact_fn with the dgelu<fp32, fp32> functor. Presumably
 *  computes output = grad * gelu'(input) — confirm in activation_template.h.
 *
 *  \param[in]  grad    Handle of the incoming gradient tensor.
 *  \param[in]  input   Handle of the forward-pass input tensor.
 *  \param[out] output  Handle of the tensor that receives the result.
 *  \param[in]  stream  CUDA stream passed through to dact_fn.
 */
void nvte_dgelu(const NVTETensor grad,
                const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgelu);
  using namespace transformer_engine;
  dact_fn<fp32, Empty, dgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
                                          *reinterpret_cast<const Tensor*>(input),
                                          reinterpret_cast<Tensor*>(output),
                                          stream);
}

/*! \brief Gated GELU (GeGLU) forward pass (C API entry point).
 *
 *  Casts the opaque NVTETensor handles to transformer_engine::Tensor and
 *  forwards them to gated_act_fn with the gelu<fp32, fp32> functor. How
 *  input is split into activated and gating parts is decided inside
 *  gated_act_fn (presumably two halves of the last dimension — confirm).
 *
 *  \param[in]  input   Handle of the tensor to activate.
 *  \param[out] output  Handle of the tensor that receives the result.
 *  \param[in]  stream  CUDA stream passed through to gated_act_fn.
 */
void nvte_geglu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_geglu);
  using namespace transformer_engine;
  gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                              reinterpret_cast<Tensor*>(output),
                                              stream);
}

/*! \brief Gated GELU (GeGLU) backward pass (C API entry point).
 *
 *  Casts the opaque NVTETensor handles to transformer_engine::Tensor and
 *  forwards them to dgated_act_fn. Both the forward functor gelu<fp32, fp32>
 *  and its derivative dgelu<fp32, fp32> are supplied, presumably because the
 *  gradient of the gating part needs gelu(x) itself — confirm in
 *  activation_template.h.
 *
 *  \param[in]  grad    Handle of the incoming gradient tensor.
 *  \param[in]  input   Handle of the forward-pass input tensor.
 *  \param[out] output  Handle of the tensor that receives the result.
 *  \param[in]  stream  CUDA stream passed through to dgated_act_fn.
 */
void nvte_dgeglu(const NVTETensor grad,
                 const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgeglu);
  using namespace transformer_engine;
  dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(
    *reinterpret_cast<const Tensor*>(grad),
    *reinterpret_cast<const Tensor*>(input),
    reinterpret_cast<Tensor*>(output),
    stream);
}

/*! \brief qgelu activation forward pass (C API entry point).
 *
 *  Casts the opaque NVTETensor handles to transformer_engine::Tensor and
 *  forwards them to act_fn with the qgelu<fp32, fp32> functor. The name
 *  suggests the "quick GELU" approximation; check the functor definition
 *  (presumably in ../util/math.h) before relying on that.
 *
 *  \param[in]  input   Handle of the tensor to activate.
 *  \param[out] output  Handle of the tensor that receives the result.
 *  \param[in]  stream  CUDA stream passed through to act_fn.
 */
void nvte_qgelu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_qgelu);
  using namespace transformer_engine;
  act_fn<fp32, Empty, qgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                         reinterpret_cast<Tensor*>(output),
                                         stream);
}

/*! \brief qgelu activation backward pass (C API entry point).
 *
 *  Casts the opaque NVTETensor handles to transformer_engine::Tensor and
 *  forwards them to dact_fn with the dqgelu<fp32, fp32> functor. Presumably
 *  computes output = grad * qgelu'(input) — confirm in activation_template.h.
 *
 *  \param[in]  grad    Handle of the incoming gradient tensor.
 *  \param[in]  input   Handle of the forward-pass input tensor.
 *  \param[out] output  Handle of the tensor that receives the result.
 *  \param[in]  stream  CUDA stream passed through to dact_fn.
 */
void nvte_dqgelu(const NVTETensor grad,
                 const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dqgelu);
  using namespace transformer_engine;
  dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
                                           *reinterpret_cast<const Tensor*>(input),
                                           reinterpret_cast<Tensor*>(output),
                                           stream);
}