activation_template.h 5.55 KB
Newer Older
1
2
3
4
5
6
7
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cuda_runtime.h>
8
#include <transformer_engine/activation.h>
9

10
11
#include "../common.h"
#include "../util/vectorized_pointwise.h"
12
13
14

namespace transformer_engine {

15
16
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
17
18
19
20
21
  CheckInputTensor(input, "act_lu_input");
  CheckOutputTensor(*output, "act_lu_output");
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
  const size_t tot_elts = product(input.data.shape);

22
23
24
25
26
27
28
29
30
31
32
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
          VectorizedUnaryKernelLauncher<nvec, Param, OP>(
              reinterpret_cast<const IType *>(input.data.dptr),
              reinterpret_cast<OType *>(output->data.dptr),
              reinterpret_cast<const ComputeType *>(output->scale.dptr),
              reinterpret_cast<ComputeType *>(output->amax.dptr), tot_elts, {},
              stream););  // NOLINT(*)
  );                      // NOLINT(*)
33
34
}

35
36
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) {
37
38
39
40
  CheckInputTensor(input, "dact_lu_input");
  CheckInputTensor(grad, "dact_lu_input_grad");
  CheckOutputTensor(*output, "dact_lu_output");
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
41
  NVTE_CHECK(input.data.dtype == grad.data.dtype, "Input and incoming gradient types must match.");
42
43
  const size_t tot_elts = product(input.data.shape);

44
45
46
47
48
49
50
51
52
53
54
55
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
          VectorizedUnaryGradKernelLauncher<nvec, Param, OP>(
              reinterpret_cast<const IType *>(grad.data.dptr),
              reinterpret_cast<const IType *>(input.data.dptr),
              reinterpret_cast<OType *>(output->data.dptr),
              reinterpret_cast<const ComputeType *>(output->scale.dptr),
              reinterpret_cast<ComputeType *>(output->amax.dptr), tot_elts, {},
              stream););  // NOLINT(*)
  );                      // NOLINT(*)
56
57
}

58
59
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
60
61
62
63
64
65
66
67
68
  CheckInputTensor(input, "gated_act_input");
  CheckOutputTensor(*output, "gated_act_output");
  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
  NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
             "Input shape[0] must be equal to output shape[0].");
  NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
             "Input shape[1] must be 2x larger than output shape[1].");

69
70
71
72
73
74
75
76
77
78
79
80
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
          GatedActivationKernelLauncher<nvec, ComputeType, Param, OP>(
              reinterpret_cast<const IType *>(input.data.dptr),
              reinterpret_cast<OType *>(output->data.dptr),
              reinterpret_cast<const ComputeType *>(output->scale.dptr),
              reinterpret_cast<ComputeType *>(output->amax.dptr), output->data.shape[0],
              output->data.shape[1], {},
              stream););  // NOLINT(*)
  );                      // NOLINT(*)
81
82
}

83
84
85
template <typename ComputeType, typename Param, ComputeType (*OP1)(ComputeType, const Param &),
          ComputeType (*OP2)(ComputeType, const Param &)>
void dgated_act_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) {
86
87
88
89
90
91
92
93
94
95
  CheckInputTensor(grad, "dgated_act_grad");
  CheckInputTensor(input, "dgated_act_input");
  CheckOutputTensor(*output, "dgated_act_output");
  NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions.");
  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
  NVTE_CHECK(output->data.shape[0] == grad.data.shape[0],
             "Output shape[0] must be equal to grad shape[0].");
  NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
             "Output shape[1] must be 2x larger than grad shape[1].");
96
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
97

98
99
100
101
102
103
104
105
106
107
108
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
          DGatedActivationKernelLauncher<nvec, ComputeType, Param, OP1, OP2>(
              reinterpret_cast<const IType *>(grad.data.dptr),
              reinterpret_cast<const IType *>(input.data.dptr),
              reinterpret_cast<OType *>(output->data.dptr), grad.data.shape[0], grad.data.shape[1],
              {},
              stream););  // NOLINT(*)
  );                      // NOLINT(*)
109
110
111
}

}  // namespace transformer_engine