gelu.cu 3.25 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
6
#include "../util/math.h"
7
#include "./activation_template.h"
Przemek Tredak's avatar
Przemek Tredak committed
8

9
// nvte_gelu: reinterprets the NVTETensor handles as transformer_engine::Tensor
// objects and dispatches to act_fn instantiated with gelu<fp32, fp32>.
//   input  - handle viewed as a read-only Tensor (dereferenced here).
//   output - handle viewed as a mutable Tensor, forwarded as a pointer.
//   stream - CUDA stream passed through unchanged to act_fn.
void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_gelu);
  using namespace transformer_engine;
  const Tensor& input_tensor = *reinterpret_cast<const Tensor*>(input);
  Tensor* const output_tensor = reinterpret_cast<Tensor*>(output);
  act_fn<fp32, Empty, gelu<fp32, fp32>>(input_tensor, output_tensor, stream);
}

16
// nvte_dgelu: reinterprets the NVTETensor handles as transformer_engine::Tensor
// objects and dispatches to dact_fn instantiated with dgelu<fp32, fp32>.
//   grad   - handle viewed as a read-only Tensor (first argument to dact_fn).
//   input  - handle viewed as a read-only Tensor (second argument to dact_fn).
//   output - handle viewed as a mutable Tensor, forwarded as a pointer.
//   stream - CUDA stream passed through unchanged to dact_fn.
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgelu);
  using namespace transformer_engine;
  const Tensor& grad_tensor = *reinterpret_cast<const Tensor*>(grad);
  const Tensor& input_tensor = *reinterpret_cast<const Tensor*>(input);
  Tensor* const output_tensor = reinterpret_cast<Tensor*>(output);
  dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad_tensor, input_tensor, output_tensor, stream);
}
24

25
// nvte_geglu: reinterprets the NVTETensor handles as transformer_engine::Tensor
// objects and dispatches to gated_act_fn instantiated with gelu<fp32, fp32>.
//   input  - handle viewed as a read-only Tensor (dereferenced here).
//   output - handle viewed as a mutable Tensor, forwarded as a pointer.
//   stream - CUDA stream passed through unchanged to gated_act_fn.
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_geglu);
  using namespace transformer_engine;
  const Tensor& input_tensor = *reinterpret_cast<const Tensor*>(input);
  Tensor* const output_tensor = reinterpret_cast<Tensor*>(output);
  gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(input_tensor, output_tensor, stream);
}

32
// nvte_dgeglu: reinterprets the NVTETensor handles as transformer_engine::Tensor
// objects and dispatches to dgated_act_fn instantiated with both
// gelu<fp32, fp32> and dgelu<fp32, fp32>.
//   grad   - handle viewed as a read-only Tensor (first argument).
//   input  - handle viewed as a read-only Tensor (second argument).
//   output - handle viewed as a mutable Tensor, forwarded as a pointer.
//   stream - CUDA stream passed through unchanged to dgated_act_fn.
void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgeglu);
  using namespace transformer_engine;
  const Tensor& grad_tensor = *reinterpret_cast<const Tensor*>(grad);
  const Tensor& input_tensor = *reinterpret_cast<const Tensor*>(input);
  Tensor* const output_tensor = reinterpret_cast<Tensor*>(output);
  dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(grad_tensor, input_tensor,
                                                                  output_tensor, stream);
}
40

41
// nvte_qgelu: reinterprets the NVTETensor handles as transformer_engine::Tensor
// objects and dispatches to act_fn instantiated with qgelu<fp32, fp32>.
//   input  - handle viewed as a read-only Tensor (dereferenced here).
//   output - handle viewed as a mutable Tensor, forwarded as a pointer.
//   stream - CUDA stream passed through unchanged to act_fn.
void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_qgelu);
  using namespace transformer_engine;
  const Tensor& input_tensor = *reinterpret_cast<const Tensor*>(input);
  Tensor* const output_tensor = reinterpret_cast<Tensor*>(output);
  act_fn<fp32, Empty, qgelu<fp32, fp32>>(input_tensor, output_tensor, stream);
}

48
49
// nvte_dqgelu: reinterprets the NVTETensor handles as transformer_engine::Tensor
// objects and dispatches to dact_fn instantiated with dqgelu<fp32, fp32>.
//   grad   - handle viewed as a read-only Tensor (first argument to dact_fn).
//   input  - handle viewed as a read-only Tensor (second argument to dact_fn).
//   output - handle viewed as a mutable Tensor, forwarded as a pointer.
//   stream - CUDA stream passed through unchanged to dact_fn.
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dqgelu);
  using namespace transformer_engine;
  const Tensor& grad_tensor = *reinterpret_cast<const Tensor*>(grad);
  const Tensor& input_tensor = *reinterpret_cast<const Tensor*>(input);
  Tensor* const output_tensor = reinterpret_cast<Tensor*>(output);
  dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad_tensor, input_tensor, output_tensor, stream);
}
56

57
// nvte_qgeglu: reinterprets the NVTETensor handles as transformer_engine::Tensor
// objects and dispatches to gated_act_fn instantiated with qgelu<fp32, fp32>.
//   input  - handle viewed as a read-only Tensor (dereferenced here).
//   output - handle viewed as a mutable Tensor, forwarded as a pointer.
//   stream - CUDA stream passed through unchanged to gated_act_fn.
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_qgeglu);
  using namespace transformer_engine;
  const Tensor& input_tensor = *reinterpret_cast<const Tensor*>(input);
  Tensor* const output_tensor = reinterpret_cast<Tensor*>(output);
  gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(input_tensor, output_tensor, stream);
}

64
65
// nvte_dqgeglu: reinterprets the NVTETensor handles as transformer_engine::Tensor
// objects and dispatches to dgated_act_fn instantiated with both
// qgelu<fp32, fp32> and dqgelu<fp32, fp32>.
//   grad   - handle viewed as a read-only Tensor (first argument).
//   input  - handle viewed as a read-only Tensor (second argument).
//   output - handle viewed as a mutable Tensor, forwarded as a pointer.
//   stream - CUDA stream passed through unchanged to dgated_act_fn.
void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_dqgeglu);
  using namespace transformer_engine;
  const Tensor& grad_tensor = *reinterpret_cast<const Tensor*>(grad);
  const Tensor& input_tensor = *reinterpret_cast<const Tensor*>(input);
  Tensor* const output_tensor = reinterpret_cast<Tensor*>(output);
  dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(grad_tensor, input_tensor,
                                                                    output_tensor, stream);
}