activation_template.h 2.76 KB
Newer Older
1
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

7
8
9
10
11
12
13
/*! \file activation_template.h
 *  \brief Activation functions template.
 */

#ifndef TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_
#define TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_

14
#include <cuda_runtime.h>
15
#include <transformer_engine/activation.h>
16

17
#include "../common.h"
18
19
20
#include "../util/cast_gated_kernels.cuh"
#include "../util/cast_kernels.cuh"
#include "../util/math.h"
21
#include "../util/vectorized_pointwise.h"
22
23
24

namespace transformer_engine {

25
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
26
27
28
29
30
31
32
33
void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  using namespace detail;
  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = false;
  constexpr bool IS_ACT = true;
  constexpr NVTETensor dbias = nullptr;
  constexpr NVTETensor workspace = nullptr;
  constexpr const NVTETensor grad = nullptr;
34

35
36
  quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias,
                                                        workspace, stream);
37
38
}

39
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
40
41
42
43
44
45
46
47
void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
             cudaStream_t stream) {
  using namespace detail;
  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;
  constexpr NVTETensor dbias = nullptr;
  constexpr NVTETensor workspace = nullptr;
48

49
50
  quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias,
                                                        workspace, stream);
51
52
}

53
54
55
56
57
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
void gated_act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  using namespace detail;
  constexpr bool IS_DGATED = false;
  constexpr NVTETensor grad = nullptr;
58

59
  quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, stream);
60
61
}

62
63
64
65
66
67
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
          ComputeType (*DActOP)(ComputeType, const Param &)>
void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                   cudaStream_t stream) {
  using namespace detail;
  constexpr bool IS_DGATED = true;
68

69
  quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, stream);
70
71
72
}

}  // namespace transformer_engine
73
74

#endif  // TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_