/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_

#include "../common.h"

namespace transformer_engine::detail {

// Prototype only: this header declares cast_transpose but does not define it.
//
// Parameters (as visible in the signature):
//   input   - read-only source tensor (const reference).
//   noop    - read-only tensor (const reference).
//             NOTE(review): presumably a flag that allows the operation to be
//             skipped - confirm against the definition.
//   output_ - non-const pointer to the destination tensor.
//             NOTE(review): presumably receives both the cast data and its
//             transpose - confirm against the definition.
//   stream  - CUDA stream handle.
//             NOTE(review): presumably the stream the work is enqueued on - confirm.
void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream);

// Prototype only: function template declared here, defined elsewhere.
//
// Template parameters:
//   IS_DBIAS, IS_DACT, IS_ACT - compile-time boolean switches.
//             NOTE(review): going by the names these presumably enable a bias-gradient
//             reduction, an activation-backward step and an activation-forward step
//             respectively - confirm against the definition.
//   ComputeType - type that OP takes as its first argument and returns.
//   ParamOP     - type of the parameter object passed to OP by const reference.
//   OP          - pointer to a function with signature
//                 ComputeType(ComputeType, const ParamOP &).
//
// Function parameters:
//   input     - read-only source tensor (const reference).
//   act_input - pointer to a read-only tensor.
//               NOTE(review): may only be meaningful for some flag combinations - confirm.
//   output    - non-const pointer to the destination tensor.
//   dbias     - non-const pointer to a tensor.
//               NOTE(review): presumably written only when IS_DBIAS is true - confirm.
//   workspace - non-const pointer to a tensor.
//               NOTE(review): presumably scratch storage - confirm.
//   stream    - CUDA stream handle.
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ComputeType, typename ParamOP,
          ComputeType (*OP)(ComputeType, const ParamOP &)>
void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *output,
                          Tensor *dbias, Tensor *workspace, cudaStream_t stream);

// Prototype only: function template declared here, defined elsewhere.
//
// Template parameters:
//   ComputeType - type that OP1/OP2 take as their first argument and return.
//   ParamOP     - type of the parameter object passed to OP1/OP2 by const reference.
//   OP1, OP2    - pointers to functions that share the signature
//                 ComputeType(ComputeType, const ParamOP &).
//                 NOTE(review): presumably an activation and its derivative - confirm
//                 which is which against the definition.
//
// Function parameters:
//   input           - read-only source tensor (const reference).
//   gated_act_input - read-only tensor (const reference).
//                     NOTE(review): presumably the input of the forward gated
//                     activation - confirm.
//   output          - non-const pointer to the destination tensor.
//   stream          - CUDA stream handle.
template <typename ComputeType, typename ParamOP, ComputeType (*OP1)(ComputeType, const ParamOP &),
          ComputeType (*OP2)(ComputeType, const ParamOP &)>
void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, Tensor *output,
                               cudaStream_t stream);

// Block-wise quantize/transpose entry points follow (declarations only).
// Prototype only: declared here, defined elsewhere.
//
// Parameters (as visible in the signature):
//   input            - read-only source (const reference).
//   scale_inv        - mutable reference.
//   scale_inv_t      - mutable reference.
//   output           - mutable reference.
//   output_t         - mutable reference.
//                      NOTE(review): the "_t" pair presumably holds the transposed
//                      result and its scales, and "scale_inv" presumably stores
//                      inverse scaling factors - confirm against the definition.
//   epsilon          - float passed by value.
//                      NOTE(review): presumably a numerical floor used when computing
//                      scales - confirm.
//   return_transpose - bool passed by value.
//                      NOTE(review): presumably selects whether the "_t" outputs are
//                      produced - confirm.
//   pow_2_scale      - bool passed by value.
//                      NOTE(review): presumably restricts scales to powers of two - confirm.
//   stream           - CUDA stream handle.
//
// NOTE(review): "square" presumably refers to the shape of the quantization block;
// the block dimensions are not visible in this header.
void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
                                         SimpleTensor &scale_inv_t, SimpleTensor &output,
                                         SimpleTensor &output_t, const float epsilon,
                                         const bool return_transpose, const bool pow_2_scale,
                                         cudaStream_t stream);

// Prototype only: declared here, defined elsewhere. The parameter list is identical
// to that of quantize_transpose_square_blockwise.
//
// Parameters (as visible in the signature):
//   input            - read-only source (const reference).
//   scale_inv        - mutable reference.
//   scale_inv_t      - mutable reference.
//   output           - mutable reference.
//   output_t         - mutable reference.
//                      NOTE(review): the "_t" pair presumably holds the transposed
//                      result and its scales - confirm against the definition.
//   epsilon          - float passed by value.
//   return_transpose - bool passed by value.
//                      NOTE(review): presumably selects whether the "_t" outputs are
//                      produced - confirm.
//   pow_2_scale      - bool passed by value.
//                      NOTE(review): presumably restricts scales to powers of two - confirm.
//   stream           - CUDA stream handle.
//
// NOTE(review): "vector" presumably means one-dimensional quantization blocks, in
// contrast to the "square" variant; the block length is not visible in this header.
void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
                                         SimpleTensor &scale_inv_t, SimpleTensor &output,
                                         SimpleTensor &output_t, const float epsilon,
                                         const bool return_transpose, const bool pow_2_scale,
                                         cudaStream_t stream);

}  // namespace transformer_engine::detail

#endif  // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_