/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

7
#include "./activation_template.h"
8
9
10
#include "../util/math.h"


11
/*! \brief Apply the `silu` activation element-wise from `input` into `output`.
 *
 *  Public nvte_* entry point. It reinterprets the opaque NVTETensor handles
 *  as transformer_engine::Tensor and forwards them to act_fn, instantiated
 *  with the silu<fp32, fp32> functor.
 *  NOTE(review): `silu` presumably is SiLU, x * sigmoid(x), from
 *  ../util/math.h. `fp32` presumably is the compute type and `Empty` an
 *  unused parameter type. Confirm both against activation_template.h.
 *
 *  \param[in]  input   Handle of the tensor to activate.
 *  \param[out] output  Handle of the result tensor, passed to act_fn as a
 *                      mutable Tensor*.
 *  \param[in]  stream  CUDA stream forwarded to act_fn. The work is presumably
 *                      enqueued asynchronously on it; TODO confirm.
 */
void nvte_silu(const NVTETensor input,
               NVTETensor output,
               cudaStream_t stream) {
  NVTE_API_CALL(nvte_silu);
  using namespace transformer_engine;
  // NVTETensor is an opaque C handle; view it as the internal Tensor type.
  const Tensor &in = *reinterpret_cast<const Tensor*>(input);
  Tensor *out = reinterpret_cast<Tensor*>(output);
  act_fn<fp32, Empty, silu<fp32, fp32>>(in, out, stream);
}

21
/*! \brief Backward pass of `silu`: combine the incoming gradient with the
 *         `dsilu` derivative evaluated at the forward input.
 *
 *  Public nvte_* entry point. It reinterprets the opaque NVTETensor handles
 *  as transformer_engine::Tensor and forwards them to dact_fn, instantiated
 *  with the dsilu<fp32, fp32> functor.
 *  NOTE(review): the result is presumably grad * silu'(input). `grad`, `input`
 *  and `output` presumably share one shape. Confirm in activation_template.h.
 *
 *  \param[in]  grad    Handle of the upstream gradient tensor.
 *  \param[in]  input   Handle of the tensor that was fed to the forward silu.
 *  \param[out] output  Handle of the result tensor, passed to dact_fn as a
 *                      mutable Tensor*.
 *  \param[in]  stream  CUDA stream forwarded to dact_fn. The work is presumably
 *                      enqueued asynchronously on it; TODO confirm.
 */
void nvte_dsilu(const NVTETensor grad,
                const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dsilu);
  using namespace transformer_engine;
  // NVTETensor is an opaque C handle; view it as the internal Tensor type.
  const Tensor &grad_t = *reinterpret_cast<const Tensor*>(grad);
  const Tensor &in = *reinterpret_cast<const Tensor*>(input);
  Tensor *out = reinterpret_cast<Tensor*>(output);
  dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad_t, in, out, stream);
}

/*! \brief SwiGLU forward: a gated activation that uses `silu` as the
 *         activation function.
 *
 *  Public nvte_* entry point. It reinterprets the opaque NVTETensor handles
 *  as transformer_engine::Tensor and forwards them to gated_act_fn,
 *  instantiated with the silu<fp32, fp32> functor.
 *  NOTE(review): gated_act_fn presumably splits `input` into an activation
 *  half and a gate half and computes silu(a) * b. If so, `output` has half the
 *  last dimension of `input`. Which half is activated, and the exact shape
 *  contract, should be confirmed in activation_template.h.
 *
 *  \param[in]  input   Handle of the packed (activation, gate) tensor.
 *  \param[out] output  Handle of the result tensor, passed to gated_act_fn as
 *                      a mutable Tensor*.
 *  \param[in]  stream  CUDA stream forwarded to gated_act_fn. The work is
 *                      presumably enqueued asynchronously on it; TODO confirm.
 */
void nvte_swiglu(const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_swiglu);
  using namespace transformer_engine;
  // NVTETensor is an opaque C handle; view it as the internal Tensor type.
  const Tensor &in = *reinterpret_cast<const Tensor*>(input);
  Tensor *out = reinterpret_cast<Tensor*>(output);
  gated_act_fn<fp32, Empty, silu<fp32, fp32>>(in, out, stream);
}

/*! \brief SwiGLU backward: gradients of the gated `silu` activation.
 *
 *  Public nvte_* entry point. It reinterprets the opaque NVTETensor handles
 *  as transformer_engine::Tensor and forwards them to dgated_act_fn. That call
 *  is instantiated with both the forward functor (silu<fp32, fp32>) and its
 *  derivative (dsilu<fp32, fp32>). Both are passed, presumably because the
 *  gate half's gradient needs silu(a) and the activation half's gradient
 *  needs silu'(a).
 *  NOTE(review): `output` presumably matches the shape of the packed `input`,
 *  holding gradients for both halves, while `grad` matches the forward output.
 *  Confirm in activation_template.h.
 *
 *  \param[in]  grad    Handle of the upstream gradient tensor.
 *  \param[in]  input   Handle of the packed (activation, gate) forward input.
 *  \param[out] output  Handle of the result tensor, passed to dgated_act_fn
 *                      as a mutable Tensor*.
 *  \param[in]  stream  CUDA stream forwarded to dgated_act_fn. The work is
 *                      presumably enqueued asynchronously on it; TODO confirm.
 */
void nvte_dswiglu(const NVTETensor grad,
                  const NVTETensor input,
                  NVTETensor output,
                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_dswiglu);
  using namespace transformer_engine;
  // NVTETensor is an opaque C handle; view it as the internal Tensor type.
  const Tensor &grad_t = *reinterpret_cast<const Tensor*>(grad);
  const Tensor &in = *reinterpret_cast<const Tensor*>(input);
  Tensor *out = reinterpret_cast<Tensor*>(output);
  dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad_t, in, out, stream);
}