swiglu.cu 2.06 KB
Newer Older
1
/*************************************************************************
2
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
4
5
6
 *
 * See LICENSE for license information.
 ************************************************************************/

7
#include "./activation_template.h"
8
9
10
#include "../util/math.h"


11
12
13
14
15
16
17
18
/* Applies the Swish activation to `input`, writing the result into `output`.
 *
 * Dispatches to act_fn with fp32 as the compute type, an Empty parameter
 * struct (presumably: the activation takes no extra parameters), and
 * swish<fp32, fp32> as the activation functor.
 * NOTE(review): swish is presumably x * sigmoid(x) as defined in
 * ../util/math.h — confirm there.
 *
 * input  - source tensor; only read (passed as const Tensor&).
 * output - destination tensor, written by act_fn.
 * stream - CUDA stream forwarded to act_fn.
 */
void nvte_swish(const NVTETensor input,
               NVTETensor output,
               cudaStream_t stream) {
  NVTE_API_CALL(nvte_swish);
  using namespace transformer_engine;
  // NVTETensor is an opaque C-API handle; internally it is a Tensor*.
  act_fn<fp32, Empty, swish<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                         reinterpret_cast<Tensor*>(output),
                                         stream);
}

21
22
23
24
25
26
27
28
29
30
/* Backward pass of the Swish activation.
 *
 * Dispatches to dact_fn with fp32 as the compute type, an Empty parameter
 * struct, and dswish<fp32, fp32> as the derivative functor. The result is
 * written into `output`.
 * NOTE(review): presumably computes output = grad * swish'(input) — confirm
 * against dact_fn / dswish.
 *
 * grad   - incoming gradient tensor; only read.
 * input  - the forward-pass input tensor; only read.
 * output - destination tensor, written by dact_fn.
 * stream - CUDA stream forwarded to dact_fn.
 */
void nvte_dswish(const NVTETensor grad,
                const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dswish);
  using namespace transformer_engine;
  // Opaque C-API handles are reinterpreted as the internal Tensor type.
  dact_fn<fp32, Empty, dswish<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
                                           *reinterpret_cast<const Tensor*>(input),
                                           reinterpret_cast<Tensor*>(output),
                                           stream);
}

void nvte_swiglu(const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_swiglu);
  using namespace transformer_engine;
38
39
40
  gated_act_fn<fp32, Empty, swish<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
                                               reinterpret_cast<Tensor*>(output),
                                               stream);
41
42
43
44
45
46
47
48
}

void nvte_dswiglu(const NVTETensor grad,
                  const NVTETensor input,
                  NVTETensor output,
                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_dswiglu);
  using namespace transformer_engine;
49
50
51
52
53
  dgated_act_fn<fp32, Empty, swish<fp32, fp32>, dswish<fp32, fp32>>(
    *reinterpret_cast<const Tensor*>(grad),
    *reinterpret_cast<const Tensor*>(input),
    reinterpret_cast<Tensor*>(output),
    stream);
54
}