/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "../util/math.h"
#include "./activation_template.h"

// Element-wise SiLU forward. Computation is done in fp32 (the first template
// argument of act_fn) regardless of the tensors' storage dtype; no extra
// activation parameters are needed, hence the Empty param struct.
void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_silu);
  using namespace transformer_engine;
  act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
}
// SiLU forward over a grouped tensor. Dispatches to the grouped quantize-fwd
// helper with the activation flag enabled and no quantization config (nullptr).
void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_silu);
  using namespace transformer_engine;
  dispatch::group_quantize_fwd_helper</*IS_ACT=*/true, Empty, silu<fp32, fp32>>(input, output,
                                                                                nullptr, stream);
}
// Element-wise SiLU backward: combines the incoming gradient with dsilu
// evaluated at the forward input. Computation is in fp32; no extra
// activation parameters (Empty).
void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dsilu);
  using namespace transformer_engine;
  dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
}
// SiLU backward over a grouped tensor. This entry point fuses only the
// activation gradient (no dbias), so the dbias and workspace slots of the
// helper are passed as null tensors.
void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_dsilu);
  using namespace transformer_engine;
  const NVTETensor no_dbias = nullptr;
  const NVTETensor no_workspace = nullptr;
  dispatch::group_quantize_bwd_helper</*IS_DBIAS=*/false, /*IS_DACT=*/true, Empty,
                                      dsilu<fp32, fp32>>(grad, input, output, no_dbias,
                                                         no_workspace, nullptr, stream);
}
// Fused backward op: quantize + dbias reduction + SiLU gradient. Both fusion
// flags are enabled; `workspace` holds scratch for the dbias reduction.
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dsilu);
  using namespace transformer_engine;
  dispatch::quantize_bwd_helper</*IS_DBIAS=*/true, /*IS_DACT=*/true, Empty, dsilu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}
// Grouped-tensor variant of the fused quantize + dbias + SiLU-gradient
// backward op; both fusion flags enabled.
void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input,
                                     const NVTEGroupedTensor activation_input,
                                     NVTEGroupedTensor output, NVTETensor dbias,
                                     NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_dsilu);
  using namespace transformer_engine;
  dispatch::group_quantize_bwd_helper</*IS_DBIAS=*/true, /*IS_DACT=*/true, Empty,
                                      dsilu<fp32, fp32>>(input, activation_input, output, dbias,
                                                         workspace, nullptr, stream);
}
// SwiGLU forward: gated activation using SiLU as the gating function,
// dispatched through gated_act_fn with fp32 compute. SiLU itself takes no
// parameters, so an empty param struct is passed.
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_swiglu);
  using namespace transformer_engine;
  Empty e = {};
  gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, e, stream);
}
// SwiGLU backward: dgated_act_fn needs both the forward activation (silu) and
// its derivative (dsilu) to form gradients for the gate and linear halves.
// fp32 compute; no activation parameters (Empty).
void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_dswiglu);
  using namespace transformer_engine;
  Empty e = {};
  dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad, input, output, e, stream);
}
// Clamped SwiGLU forward: same dispatch as nvte_swiglu but with the clamped
// SiLU functor, parameterized by `limit` and `alpha`.
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
                         cudaStream_t stream) {
  NVTE_API_CALL(nvte_clamped_swiglu);
  using namespace transformer_engine;
  ClampedSwiGLUParam p{limit, alpha};
  gated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>>(input, output, p, stream);
}
// Clamped SwiGLU backward: mirrors nvte_dswiglu but with the clamped SiLU
// functors, forwarding the same {limit, alpha} parameters used in forward.
void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                          float limit, float alpha, cudaStream_t stream) {
  NVTE_API_CALL(nvte_clamped_dswiglu);
  using namespace transformer_engine;
  ClampedSwiGLUParam param = {limit, alpha};
  dgated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>, clamped_dsilu<fp32, fp32>>(
      grad, input, output, param, stream);
}