/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file activation_template.h
 *  \brief Activation functions template.
 */

#ifndef TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_
#define TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_

14
#include <cuda_runtime.h>
15
#include <transformer_engine/activation.h>
16

17
18
#include "../cast/dispatch/gated.cuh"
#include "../cast/dispatch/quantize.cuh"
19
#include "../common.h"
20
21
22

namespace transformer_engine {

23
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
24
25
26
void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  using namespace detail;
  constexpr bool IS_ACT = true;
27
  dispatch::quantize_fwd_helper<IS_ACT, Empty, OP>(input, output, nullptr, stream);
28
29
}

30
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
31
32
33
34
35
36
37
void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
             cudaStream_t stream) {
  using namespace detail;
  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = true;
  constexpr NVTETensor dbias = nullptr;
  constexpr NVTETensor workspace = nullptr;
38

39
40
  dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, OP>(grad, input, output, dbias, workspace,
                                                              nullptr, stream);
41
42
}

43
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
44
void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) {
45
  using namespace detail;
46
  dispatch::quantize_gated_fwd_helper<Param, ActOP>(input, output, p, stream);
47
48
}

49
50
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
          ComputeType (*DActOP)(ComputeType, const Param &)>
51
void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p,
52
53
                   cudaStream_t stream) {
  using namespace detail;
54
  dispatch::quantize_gated_bwd_helper<Param, ActOP, DActOP>(grad, input, output, p, stream);
55
56
57
}

}  // namespace transformer_engine
58
59

#endif  // TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_