/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_

#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"

#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>

#include <cstdint>
#include <memory>
#include <mutex>
#include <tuple>

namespace transformer_engine {
namespace fused_attn {

using namespace transformer_engine;

/// Selects which attention matrix `generateMatrixStrides` describes.
/// The numeric values are explicit so they stay stable if entries are reordered.
enum NVTE_QKV_Matrix {
    NVTE_Q_Matrix = 0,            ///< queries
    NVTE_K_Matrix = 1,            ///< keys
    NVTE_K_Matrix_Transpose = 2,  ///< keys, transposed
    NVTE_V_Matrix = 3,            ///< values
    NVTE_V_Matrix_Transpose = 4,  ///< values, transposed
    NVTE_S_Matrix = 5,            ///< output of the first GEMM
    NVTE_O_Matrix = 6,            ///< final output
};

/// Presumably fills `strideA` with the strides of `matrix` under `layout`
/// (inferred from the name and the non-const output pointer — confirm against
/// the definition).
///
/// @param b, h, s_q, s_kv, d  problem dimensions (presumably batch, heads,
///                            query / KV sequence lengths and head dim — confirm).
/// @param strideA             caller-provided output array; its required
///                            length is defined by the implementation.
/// @param layout              QKV memory layout.
/// @param matrix              which attention matrix to describe.
void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv,
                           int64_t d, int64_t *strideA,
                           NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix);

/// NOTE(review): from the name, presumably an engine-config filter that
/// accepts every config — confirm against the definition.
bool allowAllConfig(cudnnBackendDescriptor_t engine_config);

/// Creates a cudnn_frontend::Tensor of data type `type` with UID `id`, using
/// the `dim` and `stride` arrays. `is_virtual` / `is_value` are tensor flags
/// (presumably "intermediate, not backed by memory" and "pass-by-value
/// scalar" — confirm).
cudnn_frontend::Tensor tensor_create(
    cudnnDataType_t type, int64_t id,
    int64_t const *dim, int64_t const *stride,
    bool is_virtual, bool is_value);

/// Same as tensor_create, additionally taking a ragged-offset tensor.
cudnn_frontend::Tensor tensor_create_with_offset(
    cudnnDataType_t type, int64_t id,
    int64_t const *dim, int64_t const *stride,
    bool is_virtual, bool is_value,
    std::shared_ptr<cudnn_frontend::Tensor> raggedOffset);

/// Creates a pointwise descriptor for `mode` computed in `type`.
cudnn_frontend::PointWiseDesc pw_desc_create(
    cudnnDataType_t type, cudnnPointwiseMode_t mode);

/// Builds a pointwise operation with one input (x) and output y.
cudnn_frontend::Operation unary_pw_op_create(
    cudnn_frontend::Tensor const &xDesc,
    cudnn_frontend::Tensor const &yDesc,
    cudnn_frontend::PointWiseDesc const &pwDesc);

/// Builds a pointwise operation with two inputs (x, b) and output y.
cudnn_frontend::Operation binary_pw_op_create(
    cudnn_frontend::Tensor const &xDesc,
    cudnn_frontend::Tensor const &bDesc,
    cudnn_frontend::Tensor const &yDesc,
    cudnn_frontend::PointWiseDesc const &pwDesc);

/// Builds a pointwise operation with three inputs (x, b, t) and output y.
cudnn_frontend::Operation ternary_pw_op_create(
    cudnn_frontend::Tensor const &xDesc,
    cudnn_frontend::Tensor const &bDesc,
    cudnn_frontend::Tensor const &tDesc,
    cudnn_frontend::Tensor const &yDesc,
    cudnn_frontend::PointWiseDesc const &pwDesc);

cyanguwa's avatar
cyanguwa committed
71
72
73
74
75
76
77
78
79
80
struct FADescriptor {
  std::int64_t b;
  std::int64_t h;
  std::int64_t s_q;
  std::int64_t s_kv;
  std::int64_t d;
  float attnScale;
  bool isTraining;
  float dropoutProbability;
  NVTE_QKV_Layout layout;
81
82
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
cyanguwa's avatar
cyanguwa committed
83
  cudnnDataType_t tensor_type;
84
  bool use_workspace_opt;
cyanguwa's avatar
cyanguwa committed
85
86
87
88

  bool operator<(const FADescriptor &rhs) const {
    return std::tie(b, h, s_q, s_kv, d,
                    attnScale, isTraining, dropoutProbability,
89
                    layout, mask_type, bias_type, tensor_type, use_workspace_opt)
90
91
92
93
                    < std::tie(
                      rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d,
                      rhs.attnScale, rhs.isTraining,
                      rhs.dropoutProbability, rhs.layout,
94
95
                      rhs.mask_type, rhs.bias_type,
                      rhs.tensor_type, rhs.use_workspace_opt);
cyanguwa's avatar
cyanguwa committed
96
97
98
  }
};

99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
struct FADescriptor_v1 {
  std::int64_t b;
  std::int64_t h;
  std::int64_t hg;
  std::int64_t s_q;
  std::int64_t s_kv;
  std::int64_t d;
  float attnScale;
  bool isTraining;
  float dropoutProbability;
  NVTE_QKV_Layout layout;
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
  cudnn_frontend::DataType_t tensor_type;

  bool operator<(const FADescriptor_v1 &rhs) const {
    return std::tie(b, h, hg, s_q, s_kv, d,
                    attnScale, isTraining, dropoutProbability,
                    layout, mask_type, bias_type, tensor_type)
                    < std::tie(
                      rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d,
                      rhs.attnScale, rhs.isTraining,
                      rhs.dropoutProbability, rhs.layout,
                      rhs.mask_type, rhs.bias_type,
                      rhs.tensor_type);
  }
};

cyanguwa's avatar
cyanguwa committed
127
128
129
130
__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d,
                int32_t *cu_seqlens_q, int32_t *actual_seqlens_q,
                int32_t *qkv_ragged_offset, int32_t *o_ragged_offset);

131
132
133
134
135
__global__ void cu_seqlens_to_actual_seqlens(size_t b,
                int32_t const * const q_cu_seqlens,
                int32_t const * const kv_cu_seqlens,
                int32_t *q_seqlens, int32_t *kv_seqlens);

cyanguwa's avatar
cyanguwa committed
136
137
138
}  // namespace fused_attn

cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
139
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
cyanguwa's avatar
cyanguwa committed
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167

/// Per-thread owner of a lazily created cuDNN handle.
///
/// `Instance()` returns a `thread_local` object, so each thread has its own
/// manager. That manager creates its handle on the first successful
/// `GetCudnnHandle()` call and destroys it when the thread's manager is
/// destroyed.
///
/// The manager owns a raw handle, so it is neither copyable nor movable: a
/// copy would hand the same handle to `cudnnDestroy` twice.
class cudnnExecutionPlanManager {
 public:
    /// Returns the calling thread's manager.
    static cudnnExecutionPlanManager &Instance() {
        static thread_local cudnnExecutionPlanManager instance;
        return instance;
    }

    // Declaring the deleted copy/move members suppresses the implicit default
    // constructor, which Instance() needs, so it is defaulted explicitly.
    cudnnExecutionPlanManager() = default;
    cudnnExecutionPlanManager(const cudnnExecutionPlanManager &) = delete;
    cudnnExecutionPlanManager &operator=(const cudnnExecutionPlanManager &) = delete;
    cudnnExecutionPlanManager(cudnnExecutionPlanManager &&) = delete;
    cudnnExecutionPlanManager &operator=(cudnnExecutionPlanManager &&) = delete;

    /// Returns this manager's cuDNN handle, creating it on first use.
    ///
    /// If `cudnnCreate` fails, nullptr is returned and creation is retried on
    /// the next call. Any partially written handle is not kept, so it is
    /// never passed to `cudnnDestroy`.
    cudnnHandle_t GetCudnnHandle() {
        if (handle_ == nullptr &&
            cudnnCreate(&handle_) != CUDNN_STATUS_SUCCESS) {
            handle_ = nullptr;
        }
        return handle_;
    }

    /// Releases the handle if one was created. State is per instance, so no
    /// function-local thread_local flag is constructed during thread exit.
    ~cudnnExecutionPlanManager() {
        if (handle_ != nullptr) {
            cudnnDestroy(handle_);
            handle_ = nullptr;
        }
    }

 private:
    cudnnHandle_t handle_ = nullptr;
};
}  // namespace transformer_engine

#endif