swiglu.cu 1.78 KB
Newer Older
1
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "../util/math.h"
8
#include "./activation_template.h"
9

10
// Forward SiLU activation (C API entry point).
//
// Reinterprets the opaque NVTETensor handles as transformer_engine::Tensor
// and forwards them, together with `stream`, to act_fn instantiated with
// silu<fp32, fp32> as the activation op.
//   input  - source tensor handle (read through a const Tensor&).
//   output - destination tensor handle (passed on as Tensor*).
//   stream - CUDA stream, forwarded unchanged to act_fn.
// NOTE(review): the leading fp32 template argument is presumably the compute
// type and Empty presumably means "no extra activation parameters"; dtype and
// shape validation, if any, live inside act_fn. Confirm in activation_template.h.
void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_silu);
  using namespace transformer_engine;
  const Tensor &src = *reinterpret_cast<const Tensor *>(input);
  Tensor *dst = reinterpret_cast<Tensor *>(output);
  act_fn<fp32, Empty, silu<fp32, fp32>>(src, dst, stream);
}

17
// Backward SiLU activation (C API entry point).
//
// Reinterprets the opaque NVTETensor handles as transformer_engine::Tensor
// and forwards them, together with `stream`, to dact_fn instantiated with
// dsilu<fp32, fp32> as the derivative op.
//   grad   - incoming gradient tensor handle (read through a const Tensor&).
//   input  - forward-pass input tensor handle (read through a const Tensor&).
//   output - destination tensor handle (passed on as Tensor*).
//   stream - CUDA stream, forwarded unchanged to dact_fn.
// NOTE(review): how grad and dsilu(input) are combined, and any dtype or shape
// checks, are defined inside dact_fn. Confirm in activation_template.h.
void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dsilu);
  using namespace transformer_engine;
  const Tensor &grad_tensor = *reinterpret_cast<const Tensor *>(grad);
  const Tensor &input_tensor = *reinterpret_cast<const Tensor *>(input);
  Tensor *output_tensor = reinterpret_cast<Tensor *>(output);
  dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad_tensor, input_tensor, output_tensor, stream);
}

26
// Forward SwiGLU (SiLU-gated linear unit) (C API entry point).
//
// Reinterprets the opaque NVTETensor handles as transformer_engine::Tensor
// and forwards them, together with `stream`, to gated_act_fn instantiated
// with silu<fp32, fp32> as the gating activation.
//   input  - source tensor handle (read through a const Tensor&).
//   output - destination tensor handle (passed on as Tensor*).
//   stream - CUDA stream, forwarded unchanged to gated_act_fn.
// NOTE(review): gated_act_fn presumably expects `input` to hold both the
// activation half and the linear (gate) half, producing an output half as
// wide. The exact split and layout are defined there; confirm before relying on it.
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_swiglu);
  using namespace transformer_engine;
  const Tensor &gated_input = *reinterpret_cast<const Tensor *>(input);
  Tensor *result = reinterpret_cast<Tensor *>(output);
  gated_act_fn<fp32, Empty, silu<fp32, fp32>>(gated_input, result, stream);
}

33
// Backward SwiGLU (C API entry point).
//
// Reinterprets the opaque NVTETensor handles as transformer_engine::Tensor
// and forwards them, together with `stream`, to dgated_act_fn. It is
// instantiated with silu<fp32, fp32> as the forward activation and
// dsilu<fp32, fp32> as its derivative. Both are supplied because the gated
// backward presumably needs the activation value (for the gate-half gradient)
// and its derivative (for the activation-half gradient). Confirm in
// activation_template.h.
//   grad   - incoming gradient tensor handle (read through a const Tensor&).
//   input  - forward-pass input tensor handle (read through a const Tensor&).
//   output - destination tensor handle (passed on as Tensor*).
//   stream - CUDA stream, forwarded unchanged to dgated_act_fn.
void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_dswiglu);
  using namespace transformer_engine;
  const Tensor &upstream_grad = *reinterpret_cast<const Tensor *>(grad);
  const Tensor &fwd_input = *reinterpret_cast<const Tensor *>(input);
  Tensor *input_grad = reinterpret_cast<Tensor *>(output);
  dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(upstream_grad, fwd_input,
                                                                  input_grad, stream);
}